-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RetrievalAgent] Add base retrieval agent #258
base: develop
Are you sure you want to change the base?
Changes from all commits
0a6378c
4b74f9e
9a923a1
f466020
2d167eb
2bdcfb4
4d20793
36c8f4a
e1b083d
6cb78d8
a028a62
de7728d
0d847af
1143817
9e214a8
148264d
57a0d7e
db39621
a45d01b
3735fce
e463c18
a24b61d
3800c4d
012f0b1
b6a5789
603a647
afcb615
a5ddd02
eaacf02
7122a26
ca12ea1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import json | ||
from typing import List, Optional, Sequence | ||
|
||
from erniebot_agent.agents.agent import Agent | ||
from erniebot_agent.agents.schema import AgentResponse, AgentStep | ||
from erniebot_agent.file import File | ||
from erniebot_agent.memory.messages import HumanMessage, Message | ||
from erniebot_agent.prompt import PromptTemplate | ||
from erniebot_agent.tools.langchain_retrieval_tool import LangChainRetrievalTool | ||
|
||
# Zero-shot prompt: ask the LLM to split a question into at most 5 simple
# sub-questions, answering strictly as a JSON object.
ZERO_SHOT_QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
1.严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
问题:{{query}} 子问题:"""

# Same decomposition task, but with retrieved few-shot examples rendered via
# the Jinja loop over ``documents`` (each doc supplies content/sub_queries).
FEW_SHOT_QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
示例:
##
{% for doc in documents %}
问题:{{doc['content']}}
子问题:{{doc['sub_queries']}}
{% endfor %}
##
问题:{{query}}
子问题:
"""

# Answer-generation prompt: list the retrieved passages, then ask the LLM to
# answer the original query from them.
RAG_PROMPT = """检索结果:
{% for doc in documents %}
第{{loop.index}}个段落: {{doc['content']}}
{% endfor %}
检索语句: {{query}}
请根据以上检索结果回答检索语句的问题"""


# Contextual-compression prompt: extract only the passage parts relevant to
# the query, verbatim; return the sentinel ``NO_OUTPUT`` when nothing matches.
# NOTE: this previously contained the unresolved placeholder
# ``{no_output_str}``; the template is only formatted with ``query`` and
# ``context``, so that placeholder was sent to the model literally.
CONTENT_COMPRESSOR = """针对以下问题和背景,提取背景中与回答问题相关的任何部分,并原样保留。如果背景中没有与问题相关的部分,则返回NO_OUTPUT。
记住,不要编辑提取的背景部分。
> 问题: {{query}}
> 背景:
>>>
{{context}}
>>>
提取的相关部分:"""

# Decomposition prompt that prepends retrieved context before the question.
CONTEXT_QUERY_DECOMPOSITION = """
{{context}} 请把下面的问题分解成子问题,子问题的数量不超过5个,每个子问题必须足够简单。要求:
严格按照【JSON格式】的形式输出:{"sub_query_1":"具体子问题1","sub_query_2":"具体子问题2"}
问题:{{query}} 子问题:
"""
|
||
|
||
class RetrievalAgent(Agent):
    """Agent that answers a query by decomposing it into sub-queries,
    retrieving supporting passages for each sub-query from a knowledge base,
    and generating the final answer with a RAG prompt.

    At most one of ``few_shot_retriever`` / ``context_retriever`` may be
    provided; each selects a different query-decomposition prompt. With
    neither, a zero-shot decomposition prompt is used.
    """

    def __init__(
        self,
        knowledge_base,
        few_shot_retriever: Optional[LangChainRetrievalTool] = None,
        context_retriever: Optional[LangChainRetrievalTool] = None,
        top_k: int = 2,
        threshold: float = 0.1,
        use_compressor: bool = False,
        **kwargs,
    ):
        """
        Args:
            knowledge_base: Awaitable retriever called as
                ``await knowledge_base(query, top_k=..., filters=...)`` and
                returning ``{"documents": [...]}``.
            few_shot_retriever: Optional retriever supplying few-shot
                decomposition examples; mutually exclusive with
                ``context_retriever``.
            context_retriever: Optional retriever supplying context that is
                prepended to the decomposition prompt.
            top_k: Number of passages fetched per sub-query.
            threshold: Relevance threshold (stored but not read in this
                class; presumably consumed by the retriever side — verify).
            use_compressor: If True, compress retrieved passages with the
                LLM before answer generation.

        Raises:
            ValueError: If both ``few_shot_retriever`` and
                ``context_retriever`` are provided.
        """
        super().__init__(**kwargs)
        self.top_k = top_k
        self.threshold = threshold

        self.knowledge_base = knowledge_base
        self.few_shot_retriever = few_shot_retriever
        self.context_retriever = context_retriever
        if self.few_shot_retriever and self.context_retriever:
            # The two retrievers select conflicting decomposition prompts.
            raise ValueError(
                "Few shot retriever and context retriever shouldn't be used simultaneously"
            )
        if self.few_shot_retriever:
            self.query_transform = PromptTemplate(
                FEW_SHOT_QUERY_DECOMPOSITION, input_variables=["query", "documents"]
            )
        elif self.context_retriever:
            self.query_transform = PromptTemplate(
                CONTEXT_QUERY_DECOMPOSITION, input_variables=["context", "query"]
            )
        else:
            self.query_transform = PromptTemplate(ZERO_SHOT_QUERY_DECOMPOSITION, input_variables=["query"])
        self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"])
        self.use_compressor = use_compressor
        self.compressor = PromptTemplate(CONTENT_COMPRESSOR, input_variables=["context", "query"])

    async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
        """Decompose ``prompt`` into sub-queries, retrieve passages for each,
        and generate the final answer.

        Args:
            prompt: The user query.
            files: Unused; accepted for interface compatibility with ``Agent``.

        Returns:
            The finished ``AgentResponse`` carrying the chat history and the
            agent steps recorded along the way.
        """
        steps_taken: List[AgentStep] = []
        if self.few_shot_retriever:
            # Fetch few-shot decomposition examples similar to the query.
            docs = await self.few_shot_retriever(prompt, 3)
            few_shots = []
            for doc in docs["documents"]:
                few_shots.append(
                    {
                        "content": doc["content"],
                        "sub_queries": doc["meta"]["sub_queries"],
                        "score": doc["score"],
                    }
                )
            steps_input = HumanMessage(
                content=self.query_transform.format(query=prompt, documents=few_shots)
            )
            steps_taken.append(
                AgentStep(info={"query": prompt, "name": "few shot retriever"}, result=few_shots)
            )
        elif self.context_retriever:
            # Prepend retrieved context to the decomposition prompt.
            res = await self.context_retriever(prompt, 3)
            context = [item["content"] for item in res["documents"]]
            steps_input = HumanMessage(
                content=self.query_transform.format(query=prompt, context="\n".join(context))
            )
            steps_taken.append(AgentStep(info={"query": prompt, "name": "context retriever"}, result=res))
        else:
            # Zero-shot: rely on the LLM alone to decompose the query.
            steps_input = HumanMessage(content=self.query_transform.format(query=prompt))
        # Query planning: ask the LLM for sub-queries as a JSON object.
        llm_resp = await self.run_llm(
            messages=[steps_input],
        )
        output_message = llm_resp.message
        json_results = self._parse_results(output_message.content)
        sub_queries = json_results.values()
        # Sub-query execution: retrieve (and optionally compress) passages.
        retrieval_results = await self.execute(sub_queries, steps_taken)

        # Answer generation from the merged retrieval results.
        step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=retrieval_results))
        chat_history: List[Message] = [step_input]
        llm_resp = await self.run_llm(
            messages=chat_history,
        )

        output_message = llm_resp.message
        chat_history.append(output_message)
        # Persist only the user turn and the final answer in memory.
        self.memory.add_message(chat_history[0])
        self.memory.add_message(chat_history[-1])
        response = self._create_finished_response(chat_history, steps_taken)
        return response

    async def execute(self, sub_queries, steps_taken: List[AgentStep]):
        """Retrieve passages for every sub-query.

        With ``use_compressor`` the passages of each sub-query are compressed
        by the LLM into a single entry; otherwise the raw passages are
        deduplicated by content and capped at 3 results.

        Args:
            sub_queries: Iterable of decomposed sub-query strings.
            steps_taken: Step list mutated in place with one entry per
                sub-query.

        Returns:
            List of passage dicts used for answer generation.
        """
        retrieval_results = []
        if self.use_compressor:
            for idx, query in enumerate(sub_queries):
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                if not docs:
                    # Nothing retrieved for this sub-query; skip instead of
                    # crashing on docs[0] below.
                    continue
                context = "\n".join([item["content"] for item in docs])
                llm_resp = await self.run_llm(
                    messages=[HumanMessage(content=self.compressor.format(query=query, context=context))]
                )
                # Reuse the first doc as the carrier so score/meta survive
                # alongside the compressed content.
                output_message = llm_resp.message
                compressed_data = docs[0]
                compressed_data["sub_query"] = query
                compressed_data["content"] = output_message.content
                retrieval_results.append(compressed_data)
                steps_taken.append(
                    AgentStep(
                        info={"query": query, "name": f"sub query compressor {idx}"}, result=compressed_data
                    )
                )
        else:
            duplicates = set()
            for idx, query in enumerate(sub_queries):
                documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
                docs = [item for item in documents["documents"]]
                steps_taken.append(
                    AgentStep(info={"query": query, "name": f"sub query results {idx}"}, result=documents)
                )
                for doc in docs:
                    if doc["content"] not in duplicates:
                        duplicates.add(doc["content"])
                        retrieval_results.append(doc)
                # Cap the merged, deduplicated results at 3 passages.
                retrieval_results = retrieval_results[:3]
        return retrieval_results

    def _parse_results(self, results):
        """Extract and parse the outermost JSON object in the LLM output.

        Raises:
            ValueError: If no JSON object delimiters are present or the
                enclosed text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass).
        """
        left_index = results.find("{")
        right_index = results.rfind("}")
        if left_index == -1 or right_index == -1:
            raise ValueError(f"No JSON object found in model output: {results!r}")
        return json.loads(results[left_index : right_index + 1])

    def _create_finished_response(
        self,
        chat_history: List[Message],
        steps: List[AgentStep],
    ) -> AgentResponse:
        """Build a FINISHED AgentResponse whose text is the last message."""
        last_message = chat_history[-1]
        return AgentResponse(
            text=last_message.content,
            chat_history=chat_history,
            steps=steps,
            status="FINISHED",
        )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
from pydantic import Field | ||
|
||
from erniebot_agent.tools.schema import ToolParameterView | ||
|
||
from .base import Tool | ||
|
||
|
||
class LangChainRetrievalToolInputView(ToolParameterView):
    """Input schema for LangChainRetrievalTool: a query string and the
    number of results to return."""

    query: str = Field(description="查询语句")
    top_k: int = Field(description="返回结果数量")
|
||
|
||
class SearchResponseDocument(ToolParameterView):
    """One retrieved passage: its title and its content."""

    title: str = Field(description="检索结果的标题")
    document: str = Field(description="检索结果的内容")
|
||
|
||
class LangChainRetrievalToolOutputView(ToolParameterView):
    """Output schema for LangChainRetrievalTool: the list of passages
    relevant to the user's query."""

    documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落")
|
||
|
||
class LangChainRetrievalTool(Tool):
    """Tool that retrieves passages relevant to a query from a LangChain
    vector store, filtered by a relevance-score threshold."""

    description: str = "在知识库中检索与用户输入query相关的段落"

    def __init__(
        self,
        db,
        threshold: float = 0.0,
        input_type=None,
        output_type=None,
        return_meta_data: bool = True,
    ) -> None:
        """
        Args:
            db: LangChain vector store exposing
                ``similarity_search_with_relevance_scores``.
            threshold: Minimum relevance score; results at or below it are
                dropped.
            input_type: Optional override for the tool's input schema view.
            output_type: Optional override for the tool's output schema view.
            return_meta_data: If True, include the document's metadata dict
                under the ``"meta"`` key of each result.
        """
        super().__init__()
        self.db = db
        self.return_meta_data = return_meta_data
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            # Bug fix: this previously assigned ``self.ouptut_type`` (typo),
            # silently discarding the caller-supplied output schema.
            self.output_type = output_type
        self.threshold = threshold

    async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
        """Search the vector store and return scored passages.

        Args:
            query: Natural-language search query.
            top_k: Maximum number of candidates requested from the store.
            filters: Accepted for interface compatibility but not forwarded
                to the store.

        Returns:
            ``{"documents": [{"content": str, "score": float, "meta": dict}, ...]}``
            (``"meta"`` only present when ``return_meta_data`` is True).
        """
        documents = self.db.similarity_search_with_relevance_scores(query, top_k)
        docs = []
        for doc, score in documents:
            # Keep only results strictly above the relevance threshold.
            if score > self.threshold:
                new_doc = {"content": doc.page_content, "score": score}
                if self.return_meta_data:
                    new_doc["meta"] = doc.metadata
                docs.append(new_doc)

        return {"documents": docs}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
什么时候会走进最后这个else branch呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
zero shot,完全靠大模型自己的能力进行子query分解的时候