Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RetrievalAgent] Add base retrieval agent #258

Open
wants to merge 31 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "python310",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "py310"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -496,7 +496,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "py310_openai"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -412,7 +412,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
659 changes: 659 additions & 0 deletions erniebot-agent/cookbook/retrieval_agent.ipynb

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions erniebot-agent/cookbook/tools_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "eb-sdk",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -307,15 +307,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "8a19f367f79553e5cd49921fbfd8af2792f58f47b1c0c637c2b65217dfab81ed"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions erniebot-agent/src/erniebot_agent/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
FunctionAgentWithRetrievalScoreTool,
FunctionAgentWithRetrievalTool,
)
from erniebot_agent.agents.retrieval_agent import RetrievalAgent
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
_logger.info(
f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}"
)
return await super()._run(prompt)
return await super()._run(prompt, files)

async def _maybe_retrieval(
self,
Expand Down Expand Up @@ -358,24 +358,12 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
return response
num_steps_taken += 1
response = self._create_stopped_response(chat_history, steps_taken)
# while num_steps_taken < self.max_steps:
# curr_step_output = await self._step(
# next_step_input, chat_history, actions_taken, files_involved
# )
# if curr_step_output is None:
# response = self._create_finished_response(chat_history, actions_taken, files_involved)
# self.memory.add_message(chat_history[0])
# self.memory.add_message(chat_history[-1])
# return response
# num_steps_taken += 1
# # response = self._create_stopped_response(chat_history, actions_taken, files_involved)
# self._create_stopped_response(chat_history, steps_taken)
return response
else:
_logger.info(
f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}"
)
return await super()._run(prompt)
return await super()._run(prompt, files)

async def _maybe_retrieval(
self,
Expand Down
191 changes: 191 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/retrieval_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import json
from typing import List, Optional, Sequence

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.schema import AgentResponse, AgentStep
from erniebot_agent.file import File
from erniebot_agent.memory.messages import HumanMessage, Message
from erniebot_agent.prompt import PromptTemplate
from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool

# Zero-shot decomposition prompt: asks the LLM to split the user query into at
# most 5 simple sub-questions and emit them as a JSON object of the form
# {"sub_query_1": ..., "sub_query_2": ...}. Used when neither a few-shot nor a
# context retriever is configured.
ZERO_SHOT_QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
1.严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
问题:{{query}} 子问题:"""

# Few-shot variant of the decomposition prompt: renders retrieved
# (question, sub_queries) example pairs via the Jinja-style {% for %} loop
# before presenting the target query. `documents` items must carry
# 'content' and 'sub_queries' keys (see RetrievalAgent._run).
FEW_SHOT_QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
示例:
##
{% for doc in documents %}
问题:{{doc['content']}}
子问题:{{doc['sub_queries']}}
{% endfor %}
##
问题:{{query}}
子问题:
"""

# Final answer-generation prompt: enumerates the retrieved passages, then asks
# the LLM to answer the original query from them.
RAG_PROMPT = """检索结果:
{% for doc in documents %}
第{{loop.index}}个段落: {{doc['content']}}
{% endfor %}
检索语句: {{query}}
请根据以上检索结果回答检索语句的问题"""


# Contextual-compression prompt: extracts, verbatim, only the parts of the
# retrieved context relevant to the query.
# NOTE(review): {no_output_str} is single-braced and not among the template's
# input_variables ("context", "query"), so it is rendered literally into the
# prompt — confirm whether it was meant to be substituted.
CONTENT_COMPRESSOR = """针对以下问题和背景,提取背景中与回答问题相关的任何部分,并原样保留。如果背景中没有与问题相关的部分,则返回{no_output_str}。
记住,不要编辑提取的背景部分。
> 问题: {{query}}
> 背景:
>>>
{{context}}
>>>
提取的相关部分:"""

# Context-augmented decomposition prompt: prepends retrieved background context
# before asking for the same JSON sub-query decomposition.
CONTEXT_QUERY_DECOMPOSITION = """
{{context}} 请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
问题:{{query}} 子问题:
"""


class RetrievalAgent(Agent):
    """Agent that answers a question via query decomposition and retrieval.

    Flow:
      1. Decompose the user query into at most 5 sub-queries with the LLM,
         optionally guided by few-shot examples or retrieved context.
      2. Retrieve passages from ``knowledge_base`` for each sub-query,
         optionally compressing each result set with the LLM.
      3. Generate the final answer from the merged retrieval results.

    Args:
        knowledge_base: Awaitable retriever, invoked as
            ``await knowledge_base(query, top_k=..., filters=...)``; expected to
            return ``{"documents": [{"content": ..., "score": ...}, ...]}``.
        few_shot_retriever: Optional retriever supplying (question, sub_queries)
            examples for few-shot decomposition. Mutually exclusive with
            ``context_retriever``.
        context_retriever: Optional retriever supplying background context for
            decomposition.
        top_k: Number of passages fetched per sub-query.
        threshold: Relevance-score threshold.
            NOTE(review): stored but not referenced inside this class — confirm
            whether filtering was meant to happen here or in the retrievers.
        use_compressor: If True, each sub-query's passages are compressed into a
            single LLM-extracted snippet before answer generation.
    """

    def __init__(
        self,
        knowledge_base,
        few_shot_retriever: Optional[LangChainRetrievalTool] = None,
        context_retriever: Optional[LangChainRetrievalTool] = None,
        top_k: int = 2,
        threshold: float = 0.1,
        use_compressor: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.top_k = top_k
        self.threshold = threshold

        self.knowledge_base = knowledge_base
        self.few_shot_retriever = few_shot_retriever
        self.context_retriever = context_retriever
        # The two decomposition aids are mutually exclusive; ValueError is
        # backward-compatible with the previous bare Exception.
        if self.few_shot_retriever and self.context_retriever:
            raise ValueError("Few shot retriever and context retriever shouldn't be used simultaneously")
        if few_shot_retriever:
            self.query_transform = PromptTemplate(
                FEW_SHOT_QUERY_DECOMPOSITION, input_variables=["query", "documents"]
            )
        elif self.context_retriever:
            self.query_transform = PromptTemplate(
                CONTEXT_QUERY_DECOMPOSITION, input_variables=["context", "query"]
            )
        else:
            self.query_transform = PromptTemplate(ZERO_SHOT_QUERY_DECOMPOSITION, input_variables=["query"])
        self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"])
        self.use_compressor = use_compressor
        self.compressor = PromptTemplate(CONTENT_COMPRESSOR, input_variables=["context", "query"])

    async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
        """Decompose ``prompt``, retrieve per sub-query, and generate an answer.

        ``files`` is accepted for interface compatibility with ``Agent`` but is
        not used by this agent.
        """
        steps_taken: List[AgentStep] = []
        # Build the decomposition prompt in exactly one of three modes:
        # few-shot examples, retrieved context, or zero-shot.
        if self.few_shot_retriever:
            # Fetch (question, sub_queries) examples to guide the decomposition.
            docs = await self.few_shot_retriever(prompt, 3)
            few_shots = []
            for doc in docs["documents"]:
                few_shots.append(
                    {
                        "content": doc["content"],
                        "sub_queries": doc["meta"]["sub_queries"],
                        "score": doc["score"],
                    }
                )
            steps_input = HumanMessage(
                content=self.query_transform.format(query=prompt, documents=few_shots)
            )
            steps_taken.append(
                AgentStep(info={"query": prompt, "name": "few shot retriever"}, result=few_shots)
            )
        elif self.context_retriever:
            # Prepend retrieved background passages to the decomposition prompt.
            res = await self.context_retriever(prompt, 3)
            context = [item["content"] for item in res["documents"]]
            steps_input = HumanMessage(
                content=self.query_transform.format(query=prompt, context="\n".join(context))
            )
            steps_taken.append(AgentStep(info={"query": prompt, "name": "context retriever"}, result=res))
        else:
            # Zero-shot: rely purely on the LLM's own decomposition ability.
            steps_input = HumanMessage(content=self.query_transform.format(query=prompt))
        # Query planning: the LLM emits the sub-queries as a JSON object.
        llm_resp = await self.run_llm(
            messages=[steps_input],
        )
        json_results = self._parse_results(llm_resp.message.content)
        sub_queries = list(json_results.values())
        # Sub-query execution against the knowledge base.
        retrieval_results = await self.execute(sub_queries, steps_taken)

        # Answer generation over the merged retrieval results.
        step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=retrieval_results))
        chat_history: List[Message] = [step_input]
        llm_resp = await self.run_llm(
            messages=chat_history,
        )
        chat_history.append(llm_resp.message)
        self.memory.add_message(chat_history[0])
        self.memory.add_message(chat_history[-1])
        return self._create_finished_response(chat_history, steps_taken)

    async def execute(self, sub_queries, steps_taken: List[AgentStep]):
        """Retrieve passages for each sub-query.

        With ``use_compressor`` each sub-query's passages are merged and
        compressed into one LLM-extracted snippet; otherwise passages are
        deduplicated by content and truncated to the top 3 overall.
        """
        retrieval_results = []
        if self.use_compressor:
            for idx, query in enumerate(sub_queries):
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = list(documents["documents"])
                if not docs:
                    # Nothing retrieved for this sub-query: record the step and
                    # continue instead of crashing on docs[0] below.
                    steps_taken.append(
                        AgentStep(info={"query": query, "name": f"sub query compressor {idx}"}, result=None)
                    )
                    continue
                context = "\n".join(item["content"] for item in docs)
                llm_resp = await self.run_llm(
                    messages=[HumanMessage(content=self.compressor.format(query=query, context=context))]
                )
                # Parse compressed results; copy the first doc so we keep its
                # metadata/score without mutating the retriever's own dict.
                compressed_data = dict(docs[0])
                compressed_data["sub_query"] = query
                compressed_data["content"] = llm_resp.message.content
                retrieval_results.append(compressed_data)
                steps_taken.append(
                    AgentStep(
                        info={"query": query, "name": f"sub query compressor {idx}"}, result=compressed_data
                    )
                )
        else:
            seen_contents = set()
            for idx, query in enumerate(sub_queries):
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = list(documents["documents"])
                steps_taken.append(
                    AgentStep(info={"query": query, "name": f"sub query results {idx}"}, result=documents)
                )
                # Deduplicate passages across sub-queries by content.
                for doc in docs:
                    if doc["content"] not in seen_contents:
                        seen_contents.add(doc["content"])
                        retrieval_results.append(doc)
            # Keep only the top 3 passages overall to bound the RAG prompt size.
            retrieval_results = retrieval_results[:3]
        return retrieval_results

    def _parse_results(self, results):
        """Extract and parse the outermost JSON object from raw LLM output.

        Raises:
            ValueError: if no JSON object can be found or parsed
                (``json.JSONDecodeError`` is a ``ValueError`` subclass, so this
                is backward-compatible with the previous behavior).
        """
        left_index = results.find("{")
        right_index = results.rfind("}")
        if left_index == -1 or right_index < left_index:
            raise ValueError(f"No JSON object found in LLM output: {results!r}")
        return json.loads(results[left_index : right_index + 1])

    def _create_finished_response(
        self,
        chat_history: List[Message],
        steps: List[AgentStep],
    ) -> AgentResponse:
        """Wrap the final chat history and steps into a FINISHED AgentResponse."""
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            steps=steps,
            status="FINISHED",
        )
10 changes: 9 additions & 1 deletion erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@ class BaizhongSearchTool(Tool):
ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

def __init__(
self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
self,
description,
db,
threshold: float = 0.0,
input_type=None,
output_type=None,
examples=None,
name=None,
) -> None:
super().__init__()
self.db = db
self.name = name
self.description = description
self.few_shot_examples = []
if input_type is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Dict, List, Optional

from pydantic import Field

from erniebot_agent.tools.schema import ToolParameterView

from .base import Tool


class LangChainRetrievalToolInputView(ToolParameterView):
    """Input schema for LangChainRetrievalTool: the query and result count."""

    query: str = Field(description="查询语句")
    top_k: int = Field(description="返回结果数量")


class SearchResponseDocument(ToolParameterView):
    """One retrieved passage: its title and body text."""

    title: str = Field(description="检索结果的标题")
    document: str = Field(description="检索结果的内容")


class LangChainRetrievalToolOutputView(ToolParameterView):
    """Output schema for LangChainRetrievalTool: the list of retrieved passages."""

    documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落")


class LangChainRetrievalTool(Tool):
    """Tool that wraps a LangChain vector store for relevance-scored retrieval.

    Args:
        db: LangChain vector store; must provide
            ``similarity_search_with_relevance_scores(query, k)``.
        threshold: Minimum relevance score (strict ``>``) a document must reach
            to be included in the results.
        input_type: Optional override for the tool's input schema.
        output_type: Optional override for the tool's output schema.
        return_meta_data: If True, each returned document carries the
            LangChain ``metadata`` dict under the ``"meta"`` key.
    """

    description: str = "在知识库中检索与用户输入query相关的段落"

    def __init__(
        self,
        db,
        threshold: float = 0.0,
        input_type=None,
        output_type=None,
        return_meta_data: bool = True,
    ) -> None:
        super().__init__()
        self.db = db
        self.return_meta_data = return_meta_data
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            # NOTE(review): "ouptut_type" is misspelled, but it matches the
            # attribute name used elsewhere in this package (e.g.
            # BaizhongSearchTool), so it is kept for consistency with the base
            # Tool machinery — confirm before renaming package-wide.
            self.ouptut_type = output_type
        self.threshold = threshold

    async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
        """Search the vector store and return score-filtered passages.

        ``filters`` is accepted for interface compatibility but is currently
        not forwarded to the vector store.

        Returns:
            ``{"documents": [{"content": ..., "score": ..., ["meta": ...]}]}``
        """
        documents = self.db.similarity_search_with_relevance_scores(query, top_k)
        docs = []
        for doc, score in documents:
            # Keep only passages strictly above the relevance threshold.
            if score > self.threshold:
                new_doc = {"content": doc.page_content, "score": score}
                if self.return_meta_data:
                    new_doc["meta"] = doc.metadata
                docs.append(new_doc)

        return {"documents": docs}