[RetrievalAgent] Add base retrieval agent #258

Status: Open. Wants to merge 31 commits into base: develop; changes shown from 5 of 31 commits.
1 change: 1 addition & 0 deletions erniebot-agent/src/erniebot_agent/agents/__init__.py
@@ -19,3 +19,4 @@
     FunctionAgentWithRetrievalScoreTool,
     FunctionAgentWithRetrievalTool,
 )
+from erniebot_agent.agents.retrieval_agent import RetrievalAgent
@@ -262,7 +262,7 @@ async def _run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
             logger.info(
                 f"Irrelevant retrieval results. Falling back to FunctionAgent for the query: {prompt}"
             )
-            return await super()._run(prompt)
+            return await super()._run(prompt, files)
 
     async def _maybe_retrieval(
         self,
@@ -359,24 +359,12 @@ async def _run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:
                 return response
             num_steps_taken += 1
         response = self._create_stopped_response(chat_history, steps_taken)
-        # while num_steps_taken < self.max_steps:
-        #     curr_step_output = await self._step(
-        #         next_step_input, chat_history, actions_taken, files_involved
-        #     )
-        #     if curr_step_output is None:
-        #         response = self._create_finished_response(chat_history, actions_taken, files_involved)
-        #         self.memory.add_message(chat_history[0])
-        #         self.memory.add_message(chat_history[-1])
-        #         return response
-        #     num_steps_taken += 1
-        #     # response = self._create_stopped_response(chat_history, actions_taken, files_involved)
-        #     self._create_stopped_response(chat_history, steps_taken)
         return response
     else:
         logger.info(
             f"Irrelevant retrieval results. Falling back to FunctionAgent for the query: {prompt}"
         )
-        return await super()._run(prompt)
+        return await super()._run(prompt, files)
 
     async def _maybe_retrieval(
         self,
133 changes: 133 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/retrieval_agent.py
@@ -0,0 +1,133 @@
import json
from typing import List, Optional

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.schema import AgentResponse, AgentStep
from erniebot_agent.file.base import File
from erniebot_agent.memory.messages import HumanMessage, Message, SystemMessage
from erniebot_agent.prompt import PromptTemplate

# Decomposes a complex question into simple sub-questions, requested as a JSON
# object mapping sub-question labels to the sub-questions themselves.
QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,每个子问题必须足够简单,要求:
1.严格按照【JSON格式】的形式输出:{'子问题1':'具体子问题1','子问题2':'具体子问题2'}
问题:{{prompt}} 子问题:"""
**Collaborator:** Don't we need to provide a few few-shot examples?

**Author:** Added a few-shot retriever instead, which is more general.
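To make the decomposition contract concrete, here is a small illustrative sketch; the query and the model reply are invented, not from a real run. Note that the format example inside the prompt uses single quotes, which `json.loads` would reject, so the fallback path in `_parse_results` (shown further down) matters.

```python
from erniebot_agent.agents.retrieval_agent import QUERY_DECOMPOSITION
from erniebot_agent.prompt import PromptTemplate

query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"])

# Message content sent to the LLM for decomposition.
msg = query_transform.format(prompt="检索增强生成是什么,它有什么局限性?")

# A well-behaved reply would look like (illustrative):
# {"子问题1": "检索增强生成是什么?", "子问题2": "检索增强生成有什么局限性?"}
# _parse_results extracts the {...} span, json.loads it, and the dict values
# become the sub-queries that drive retrieval.
```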



# Presents the retrieved passages followed by the query, for the final
# retrieval-augmented answer.
OPENAI_RAG_PROMPT = """检索结果:
{% for doc in documents %}
第{{loop.index}}个段落: {{doc['content']}}
{% endfor %}
检索语句: {{query}}
请根据以上检索结果回答检索语句的问题"""
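For reference, a sketch of how this Jinja-style template should render, assuming `PromptTemplate.format` performs ordinary Jinja rendering, as the `{% for %}` loop suggests; the exact output spacing is approximate.

```python
from erniebot_agent.agents.retrieval_agent import OPENAI_RAG_PROMPT
from erniebot_agent.prompt import PromptTemplate

rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"])
docs = [{"content": "段落一的内容"}, {"content": "段落二的内容"}]
print(rag_prompt.format(query="示例问题", documents=docs))
# Expected output, roughly:
# 检索结果:
# 第1个段落: 段落一的内容
# 第2个段落: 段落二的内容
# 检索语句: 示例问题
# 请根据以上检索结果回答检索语句的问题
```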


# Extracts, verbatim, any parts of the retrieved context relevant to the
# question; returns {no_output_str} when nothing is relevant.
CONTENT_COMPRESSOR = """针对以下问题和背景,提取背景中与回答问题相关的任何部分,并原样保留。如果背景中没有与问题相关的部分,则返回{no_output_str}。

记住,不要编辑提取的背景部分。

> 问题: {{query}}
> 背景:
>>>
{{context}}
>>>
提取的相关部分:"""
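One detail worth double-checking: the other placeholders in this file use Jinja-style double braces, while `{no_output_str}` uses single braces, so it presumably reaches the model verbatim unless it is substituted beforehand. A sketch of pre-substituting it (the `NO_OUTPUT` sentinel is an assumption, not part of this diff):

```python
from erniebot_agent.agents.retrieval_agent import CONTENT_COMPRESSOR
from erniebot_agent.prompt import PromptTemplate

NO_OUTPUT = "NO_OUTPUT"  # assumed sentinel for "nothing relevant found"
compressor_template = CONTENT_COMPRESSOR.replace("{no_output_str}", NO_OUTPUT)
extractor = PromptTemplate(compressor_template, input_variables=["context", "query"])
```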


class RetrievalAgent(Agent):
    """Answers queries over a knowledge base by decomposing the query into
    sub-questions, retrieving passages for each, and synthesizing a final
    answer with a RAG prompt."""

    def __init__(
        self, knowledge_base, top_k: int = 2, threshold: float = 0.1, use_extractor: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.top_k = top_k
        self.threshold = threshold
        self.system_message = SystemMessage(content="您是一个智能体,旨在回答有关知识库的查询。请始终使用提供的工具回答问题。不要依赖先验知识。")
        self.query_transform = PromptTemplate(QUERY_DECOMPOSITION, input_variables=["prompt"])
        self.knowledge_base = knowledge_base
        self.rag_prompt = PromptTemplate(OPENAI_RAG_PROMPT, input_variables=["documents", "query"])
        self.use_extractor = use_extractor
        self.extractor = PromptTemplate(CONTENT_COMPRESSOR, input_variables=["context", "query"])

    async def _run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse:

**Collaborator:**

1. Why isn't this written inside `_run`?
2. Logging needs to be added, and the steps need to be honored; otherwise the final response won't contain enough information.

**Author:** Logger output has been added inside `run_llm`.

        steps_taken: List[AgentStep] = []
        return await self.plan_and_execute(prompt, steps_taken)

    async def plan_and_execute(self, prompt, actions_taken):
        # Step 1: decompose the user query into simple sub-questions.
        step_input = HumanMessage(content=self.query_transform.format(prompt=prompt))
        fake_chat_history: List[Message] = [step_input]
        llm_resp = await self.run_llm(
            messages=fake_chat_history,
            functions=None,
            system=self.system_message.content if self.system_message is not None else None,
        )
        output_message = llm_resp.message

        json_results = self._parse_results(output_message.content)
        sub_queries = json_results.values()
        retrieval_results = []
        if self.use_extractor:
            # Step 2a: retrieve per sub-query, then compress the hits with the
            # extractor prompt so only passages relevant to the question remain.
            for query in sub_queries:
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                context = "\n".join([item["content"] for item in docs])
                # The extractor is formatted with the original user prompt, not
                # the sub-query, and the concatenated retrieved context.
                step_input = HumanMessage(content=self.extractor.format(query=prompt, context=context))
                local_history: List[Message] = [step_input]
                llm_resp = await self.run_llm(
                    messages=local_history,
                    functions=None,
                    system=self.system_message.content if self.system_message is not None else None,
                )
                # Parse the compressed results: reuse the top document's
                # metadata, replacing its content with the compressed text.
                output_message = llm_resp.message
                compressed_data = docs[0]
                compressed_data["sub_query"] = query
                compressed_data["content"] = output_message.content
                retrieval_results.append(compressed_data)

        else:
            # Step 2b: retrieve per sub-query and deduplicate by passage content.
            seen_contents = set()
            for query in sub_queries:
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                for doc in docs:
                    if doc["content"] not in seen_contents:
                        seen_contents.add(doc["content"])
                        retrieval_results.append(doc)
            retrieval_results = retrieval_results[:3]
        # Step 3: answer the original query over the retrieved passages.
        step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=retrieval_results))
        chat_history: List[Message] = [step_input]
        llm_resp = await self.run_llm(
            messages=chat_history,
            functions=None,
            system=self.system_message.content if self.system_message is not None else None,
        )

        output_message = llm_resp.message
        chat_history.append(output_message)
        self.memory.add_message(chat_history[0])
        self.memory.add_message(chat_history[-1])
        response = self._create_finished_response(chat_history, actions_taken)
        return response

    def _parse_results(self, results):
        left_index = results.find("{")
        right_index = results.rfind("}")
        if left_index == -1 or right_index == -1:
            # No JSON object found in the model output; return a fallback marker.
            return {"is_relevant": False}
        try:
            return json.loads(results[left_index : right_index + 1])
        except Exception:
            # Text between the braces was not valid JSON; same fallback.
            return {"is_relevant": False}
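A standalone sanity check of the parsing logic, with the method body reimplemented verbatim so the snippet runs on its own; all inputs are invented.

```python
import json

def parse_results(results: str) -> dict:
    # Same logic as RetrievalAgent._parse_results, copied for a standalone check.
    left_index = results.find("{")
    right_index = results.rfind("}")
    if left_index == -1 or right_index == -1:
        return {"is_relevant": False}
    try:
        return json.loads(results[left_index : right_index + 1])
    except Exception:
        return {"is_relevant": False}

assert parse_results('前言 {"子问题1": "问题A"} 后记') == {"子问题1": "问题A"}
assert parse_results("no braces here") == {"is_relevant": False}
# Single quotes are not valid JSON, so the prompt's own format example
# would land on the fallback path:
assert parse_results("{'子问题1': '问题A'}") == {"is_relevant": False}
```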

    def _create_finished_response(
        self,
        chat_history: List[Message],
        steps: List[AgentStep],
    ) -> AgentResponse:
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            steps=steps,
            status="FINISHED",
        )
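Putting the new agent together, a hedged usage sketch. The stub retriever's return shape follows how `plan_and_execute` reads `documents["documents"]`; the `ERNIEBot` and `WholeMemory` imports and the constructor arguments are assumptions based on the package's other agents, not part of this diff.

```python
import asyncio

from erniebot_agent.agents import RetrievalAgent
from erniebot_agent.chat_models import ERNIEBot  # assumed import path
from erniebot_agent.memory import WholeMemory  # assumed import path

async def knowledge_base(query: str, top_k: int = 2, filters=None):
    # Stub retriever: must return {"documents": [{"content": ...}, ...]}
    # to match how plan_and_execute consumes the result.
    return {"documents": [{"content": f"关于{query}的示例段落。"}]}

async def main():
    agent = RetrievalAgent(
        knowledge_base=knowledge_base,
        top_k=2,
        llm=ERNIEBot(model="ernie-3.5"),  # assumed constructor arguments
        memory=WholeMemory(),
    )
    response = await agent.run("检索增强生成是什么,它有什么局限性?")
    print(response.text)

asyncio.run(main())
```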
11 changes: 10 additions & 1 deletion erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py
@@ -31,10 +31,19 @@ class BaizhongSearchTool(Tool):
     output_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

     def __init__(
-        self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
+        self,
+        description,
+        db,
+        threshold: float = 0.0,
+        input_type=None,
+        output_type=None,
+        examples=None,
+        name=None,
     ) -> None:
         super().__init__()
         self.db = db
+        if name is not None:
+            self.name = name
         self.description = description
         self.few_shot_examples = []
         if input_type is not None:
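Finally, the new optional `name` parameter lets several search tools over different collections be registered side by side; a brief sketch, where the `db` handle and the descriptions are placeholders.

```python
from erniebot_agent.tools.baizhong_tool import BaizhongSearchTool

db = ...  # placeholder for the project's Baizhong search handle

city_tool = BaizhongSearchTool(description="城市百科检索", db=db, name="city_search")
faq_tool = BaizhongSearchTool(description="常见问题检索", db=db, name="faq_search")
# Without an explicit name, both instances would presumably share the same
# default tool name and could not be registered together.
```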