[RetrievalAgent] Add base retrieval agent #258
Open: w5688414 wants to merge 31 commits into PaddlePaddle:develop from w5688414:eb20
Changes from 5 commits
Commits (31, all by w5688414):
0a6378c  Add retrieval agent
4b74f9e  Add retrieval agent
9a923a1  Add retrieval_agent.py
f466020  Fix unitest
2d167eb  Add content compressor
2bdcfb4  Update
4d20793  Add fewshot retriever
36c8f4a  Merge branch 'develop' of https://github.com/PaddlePaddle/ERNIE-Bot-S…
e1b083d  reformat
6cb78d8  Add context retriever
a028a62  Update
de7728d  reformat
0d847af  fix mypy error
1143817  reformat
9e214a8  Add unitests
148264d  Update unitests
57a0d7e  Update RetrievalStep
db39621  Add unitests
a45d01b  Update retriever
3735fce  Update unitest
e463c18  Update retrieval agent
a24b61d  Update retrieval agent
3800c4d  Update format
012f0b1  Add retrieval agent tutorials
b6a5789  Add langchain tools
603a647  Add langchain retrieval tools
afcb615  Update codestyle
a5ddd02  Update notebook
eaacf02  Update notebooks
7122a26  Update format
ca12ea1  reformat
133 changes: 133 additions & 0 deletions
erniebot-agent/src/erniebot_agent/agents/retrieval_agent.py
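@@ -0,0 +1,133 @@
import json
from typing import List, Optional

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.schema import AgentResponse, AgentStep
from erniebot_agent.file.base import File
from erniebot_agent.memory.messages import HumanMessage, Message, SystemMessage
from erniebot_agent.prompt import PromptTemplate

# Prompt (Chinese). In English: "Decompose the question below into sub-questions; each
# sub-question must be simple enough. Requirement: 1. Output strictly in [JSON format]:
# {'sub-question 1': 'specific sub-question 1', 'sub-question 2': 'specific sub-question 2'}
# Question: {{prompt}} Sub-questions:"
QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,每个子问题必须足够简单,要求:
1.严格按照【JSON格式】的形式输出:{'子问题1':'具体子问题1','子问题2':'具体子问题2'}
问题:{{prompt}} 子问题:"""


# Prompt (Chinese). In English: "Retrieval results: / Passage {{loop.index}}: {{doc['content']}} /
# Query: {{query}} / Answer the query based on the retrieval results above."
OPENAI_RAG_PROMPT = """检索结果:
{% for doc in documents %}
第{{loop.index}}个段落: {{doc['content']}}
{% endfor %}
检索语句: {{query}}
请根据以上检索结果回答检索语句的问题"""


# Prompt (Chinese). In English: "Given the following question and context, extract any parts of
# the context that are relevant to answering the question, and keep them verbatim. If no part of
# the context is relevant, return {no_output_str}. Remember, do not edit the extracted parts.
# > Question: {{query}} > Context: >>> {{context}} >>> Extracted relevant parts:"
# NOTE: {no_output_str} is in single braces and is not listed in input_variables below, so it
# appears to reach the model as literal text.
CONTENT_COMPRESSOR = """针对以下问题和背景,提取背景中与回答问题相关的任何部分,并原样保留。如果背景中没有与问题相关的部分,则返回{no_output_str}。

记住,不要编辑提取的背景部分。

> 问题: {{query}}
> 背景:
>>>
{{context}}
>>>
提取的相关部分:"""

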
class RetrievalAgent(Agent):
    def __init__(
        self, knowledge_base, top_k: int = 2, threshold: float = 0.1, use_extractor: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.top_k = top_k
        self.threshold = threshold
        # System message (Chinese). In English: "You are an agent designed to answer queries about
        # a knowledge base. Always use the provided tools to answer questions. Do not rely on
        # prior knowledge."
        self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。")
        self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"])
        self.knowledge_base = knowledge_base
        self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"])
        self.use_extractor = use_extractor
        self.extractor = PromptTemplate(CONTENT_COMPRESSOR, input_variables=["context", "query"])

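Review note: `threshold` is stored in `__init__` but is not referenced anywhere else in the part of the diff shown here. The following is a minimal sketch of how a score cutoff could be applied to retrieved documents. It assumes each document dict carries a numeric `score` key, which is an assumption (only `content` is visible in this diff), and `filter_by_threshold` is a hypothetical helper, not part of the PR.

```python
from typing import Any, Dict, List


def filter_by_threshold(docs: List[Dict[str, Any]], threshold: float) -> List[Dict[str, Any]]:
    """Keep documents whose score is at least `threshold`.

    Assumes a numeric "score" key per document (hypothetical). Documents
    without a score are dropped rather than guessed at.
    """
    return [doc for doc in docs if doc.get("score") is not None and doc["score"] >= threshold]


docs = [
    {"content": "a", "score": 0.42},
    {"content": "b", "score": 0.05},
    {"content": "c"},  # no score available
]
kept = filter_by_threshold(docs, threshold=0.1)
```

Whether unscored documents should be dropped or kept is a design choice; dropping them is the conservative option when the cutoff is meant to guard answer quality.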
    async def _run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
        steps_taken: List[AgentStep] = []
        return await self.plan_and_execute(prompt, steps_taken)

Review comment on `_run`: logging has been added inside `run_llm`.

    async def plan_and_execute(self, prompt, actions_taken):
        step_input = HumanMessage(content=self.query_transform.format(prompt=prompt))
        fake_chat_history: List[Message] = [step_input]
        llm_resp = await self._run_llm(
            messages=fake_chat_history,
            functions=None,
            system=self.system_message.content if self.system_message is not None else None,
        )
        output_message = llm_resp.message

        json_results = self._parse_results(output_message.content)
        sub_queries = json_results.values()
        retrieval_results = []
        if self.use_extractor:
            for query in sub_queries:
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                context = "\n".join([item["content"] for item in docs])
                step_input = HumanMessage(content=self.extractor.format(query=prompt, context=context))
                local_history: List[Message] = [step_input]
                llm_resp = await self.run_llm(
                    messages=local_history,
                    functions=None,
                    system=self.system_message.content if self.system_message is not None else None,
                )
                # Parse Compressed results
                output_message = llm_resp.message
                compressed_data = docs[0]
                compressed_data["sub_query"] = query
                compressed_data["content"] = output_message.content
                retrieval_results.append(compressed_data)

        else:
            duplicates = set()
            for query in sub_queries:
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                for doc in docs:
                    if doc["content"] not in duplicates:
                        duplicates.add(doc["content"])
                        retrieval_results.append(doc)
            retrieval_results = retrieval_results[:3]
        step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=retrieval_results))
        chat_history: List[Message] = [step_input]
        llm_resp = await self.run_llm(
            messages=chat_history,
            functions=None,
            system=self.system_message.content if self.system_message is not None else None,
        )

        output_message = llm_resp.message
        chat_history.append(output_message)
        self.memory.add_message(chat_history[0])
        self.memory.add_message(chat_history[-1])
        response = self._create_finished_response(chat_history, actions_taken)
        return response

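Review note: the non-extractor branch of `plan_and_execute` merges the hits of all sub-queries, skips passages whose `content` was already seen, and then truncates to a hard-coded 3 results. The same logic can be sketched as a standalone function. `dedup_documents` and its `limit` parameter are hypothetical names introduced for illustration; exposing the cap as a parameter is an assumption of this sketch, not something the PR does.

```python
from typing import Any, Dict, Iterable, List


def dedup_documents(batches: Iterable[List[Dict[str, Any]]], limit: int) -> List[Dict[str, Any]]:
    """Merge per-sub-query result lists, dropping repeated content.

    First-seen order is preserved and at most `limit` documents are returned.
    """
    if limit <= 0:
        return []
    seen = set()
    merged: List[Dict[str, Any]] = []
    for docs in batches:
        for doc in docs:
            if doc["content"] in seen:
                continue
            seen.add(doc["content"])
            merged.append(doc)
            if len(merged) == limit:
                return merged  # stop early once the cap is reached
    return merged


batches = [
    [{"content": "x"}, {"content": "y"}],
    [{"content": "y"}, {"content": "z"}],
    [{"content": "w"}],
]
merged = dedup_documents(batches, limit=3)
```

Returning early once the cap is reached avoids scanning later batches whose results would be discarded anyway.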
    def _parse_results(self, results):
        left_index = results.find("{")
        right_index = results.rfind("}")
        if left_index == -1 or right_index == -1:
            # if invalid json, use Functional Agent
            return {"is_relevant": False}
        try:
            return json.loads(results[left_index : right_index + 1])
        except Exception:
            # if invalid json, use Functional Agent
            return {"is_relevant": False}

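Review note: when parsing fails, `_parse_results` returns `{"is_relevant": False}`, and `plan_and_execute` then iterates over `.values()`, so the value `False` is passed to the knowledge base as if it were a sub-query. One possible alternative, sketched below, is to fall back to the original question. `parse_sub_queries` is a hypothetical helper, and the fallback policy is an assumption rather than the PR's behavior. Note also that `json.loads` rejects single-quoted strings, so a model that imitates the single-quoted example in `QUERY_DECOMPOSITION` would land in the fallback path.

```python
import json
from typing import List


def parse_sub_queries(llm_output: str, original_query: str) -> List[str]:
    """Extract sub-queries from the first '{' to the last '}' of the model output.

    Falls back to [original_query] when no usable JSON object is found.
    """
    left = llm_output.find("{")
    right = llm_output.rfind("}")
    if left == -1 or right <= left:
        return [original_query]
    try:
        parsed = json.loads(llm_output[left : right + 1])
    except json.JSONDecodeError:
        return [original_query]
    if not isinstance(parsed, dict):
        return [original_query]
    # Keep only non-empty string values; anything else is not a usable query.
    queries = [value.strip() for value in parsed.values() if isinstance(value, str) and value.strip()]
    return queries or [original_query]


ok = parse_sub_queries('Sure: {"q1": "What is A?", "q2": "What is B?"} done', "orig")
no_json = parse_sub_queries("no json here", "orig")
single_quoted = parse_sub_queries("{'q1': 'single quotes'}", "orig")
```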
    def _create_finished_response(
        self,
        chat_history: List[Message],
        steps: List[AgentStep],
    ) -> AgentResponse:
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            steps=steps,
            status="FINISHED",
        )
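Review note: `plan_and_execute` awaits the knowledge base once per sub-query, one after another. Since the lookups do not depend on each other, they could also be issued concurrently. The sketch below shows that idea with `asyncio.gather`, which returns results in the order the awaitables were passed. `fake_knowledge_base` is a hypothetical stand-in that only mimics the `{"documents": [...]}` return shape seen in the diff; the real `knowledge_base` object is not shown here, so whether it tolerates concurrent calls is an open question.

```python
import asyncio
from typing import Any, Dict, List

Document = Dict[str, Any]


async def fake_knowledge_base(query: str, top_k: int = 2, filters=None) -> Dict[str, List[Document]]:
    """Hypothetical stand-in for the knowledge base: keyword match over a tiny corpus."""
    corpus = {
        "capital": "Paris is the capital of France.",
        "population": "France has about 68 million inhabitants.",
    }
    await asyncio.sleep(0)  # yield control, as a real network call would
    hits = [{"content": text} for key, text in corpus.items() if key in query]
    return {"documents": hits[:top_k]}


async def retrieve_all(sub_queries: List[str], top_k: int = 2) -> List[List[Document]]:
    """Run one lookup per sub-query concurrently; results keep sub-query order."""
    results = await asyncio.gather(*(fake_knowledge_base(q, top_k=top_k) for q in sub_queries))
    return [result["documents"] for result in results]


batches = asyncio.run(retrieve_all(["capital of France", "population of France"]))
```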
Review comment: Don't we need to provide a few few-shot examples?
Reply: Added a few-shot retriever, which is more general.
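For context on the exchange above, few-shot prompting for query decomposition means prepending worked question/sub-question pairs to the prompt so the model can imitate the output format. The sketch below illustrates the idea only; it is not the few-shot retriever added in the PR, which is not shown in this diff. `build_decomposition_prompt` and the example data are hypothetical.

```python
import json
from typing import Dict, List, Tuple

FewShot = Tuple[str, Dict[str, str]]


def build_decomposition_prompt(question: str, examples: List[FewShot]) -> str:
    """Build a decomposition prompt with worked examples placed before the real question."""
    lines = [
        "Decompose the question into simple sub-questions.",
        "Answer strictly as a JSON object.",
        "",
    ]
    for example_question, sub_questions in examples:
        lines.append(f"Question: {example_question}")
        # json.dumps emits double-quoted JSON, so the examples are themselves parseable.
        lines.append("Sub-questions: " + json.dumps(sub_questions, ensure_ascii=False))
        lines.append("")
    lines.append(f"Question: {question}")
    lines.append("Sub-questions:")
    return "\n".join(lines)


examples = [
    (
        "Who is older, A or B?",
        {"sub_question_1": "How old is A?", "sub_question_2": "How old is B?"},
    ),
]
prompt = build_decomposition_prompt("Which city is larger, X or Y?", examples)
```

Rendering the examples with `json.dumps` keeps them valid JSON, which nudges the model toward output that `json.loads` can parse.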