Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RetrievalAgent] Add base retrieval agent #258

Open
wants to merge 31 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions erniebot-agent/src/erniebot_agent/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
FunctionAgentWithRetrievalScoreTool,
FunctionAgentWithRetrievalTool,
)
from erniebot_agent.agents.retrieval_agent import RetrievalAgent
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
_logger.info(
f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}"
)
return await super()._run(prompt)
return await super()._run(prompt, files)

async def _maybe_retrieval(
self,
Expand Down Expand Up @@ -358,24 +358,12 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
return response
num_steps_taken += 1
response = self._create_stopped_response(chat_history, steps_taken)
# while num_steps_taken < self.max_steps:
# curr_step_output = await self._step(
# next_step_input, chat_history, actions_taken, files_involved
# )
# if curr_step_output is None:
# response = self._create_finished_response(chat_history, actions_taken, files_involved)
# self.memory.add_message(chat_history[0])
# self.memory.add_message(chat_history[-1])
# return response
# num_steps_taken += 1
# # response = self._create_stopped_response(chat_history, actions_taken, files_involved)
# self._create_stopped_response(chat_history, steps_taken)
return response
else:
_logger.info(
f"Irrelevant retrieval results. Fallbacking to FunctionAgent for the query: {prompt}"
)
return await super()._run(prompt)
return await super()._run(prompt, files)

async def _maybe_retrieval(
self,
Expand Down
203 changes: 203 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/retrieval_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import ast
import json
from typing import List, Optional, Sequence

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.schema import AgentResponse, AgentStep, RetrievalStep
from erniebot_agent.file import File
from erniebot_agent.memory.messages import HumanMessage, Message
from erniebot_agent.prompt import PromptTemplate

# Zero-shot planning prompt: ask the LLM to decompose a query into simple
# sub-queries.  NOTE: the example output uses single quotes, so the parser
# downstream must tolerate non-strict JSON.
ZERO_SHOT_QUERY_DECOMPOSITION = """请把下面的问题分解成子问题,每个子问题必须足够简单,要求:
1.严格按照【JSON格式】的形式输出:{'sub_query_1':'具体子问题1','sub_query_2':'具体子问题2'}
问题:{{query}} 子问题:"""

# Few-shot planning prompt: same decomposition task, with retrieved examples
# rendered through the Jinja2 loop over ``documents``.
# NOTE(review): "QUERT" is a typo for "QUERY"; the name is kept because it is
# referenced elsewhere in this module.
FEW_SHOT_QUERT_DECOMPOSITION = """请把下面的问题分解成子问题,
严格按照【JSON格式】的形式输出:{'sub_query_1':'具体子问题1','sub_query_2':'具体子问题2'}。
示例:
##
{% for doc in documents %}
问题:{{doc['content']}}
子问题:{{doc['sub_queries']}}
{% endfor %}
##
问题:{{query}}
子问题:
"""

# Answer-generation prompt: numbered retrieved passages followed by the query.
RAG_PROMPT = """检索结果:
{% for doc in documents %}
第{{loop.index}}个段落: {{doc['content']}}
{% endfor %}
检索语句: {{query}}
请根据以上检索结果回答检索语句的问题"""


# Contextual-compression prompt: extract, verbatim, only the parts of the
# context relevant to the query.
CONTENT_COMPRESSOR = """针对以下问题和背景,提取背景中与回答问题相关的任何部分,并原样保留。如果背景中没有与问题相关的部分,则返回{no_output_str}。

记住,不要编辑提取的背景部分。

> 问题: {{query}}
> 背景:
>>>
{{context}}
>>>
提取的相关部分:"""

# Context-aware planning prompt: decompose the query given background context.
CONTEXT_PLANNING = """
{{context}} 请根据上述背景信息把下面的问题分解成子问题,每个子问题必须足够简单,要求:
严格按照【JSON格式】的形式输出:{'sub_query_1':'具体子问题1','sub_query_2':'具体子问题2'}。
问题:{{query}} 子问题:
"""


class FaissFewShotSearch:
    """Few-shot example retriever backed by a LangChain-style vector store.

    The wrapped ``db`` must expose ``similarity_search_with_relevance_scores``
    (a LangChain vector-store API).  Each returned document is expected to
    carry a ``sub_queries`` entry in its metadata.
    """

    def __init__(self, db):
        # db: a LangChain-compatible vector store (e.g. FAISS).
        self.db = db

    def search(self, query: str, top_k: int = 10, **kwargs):
        """Return up to ``top_k`` few-shot examples for ``query``.

        Each result is a dict with ``content``, ``sub_queries`` and ``score``.
        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        scored_docs = self.db.similarity_search_with_relevance_scores(query, top_k)
        return [
            {
                "content": doc.page_content,
                "sub_queries": doc.metadata["sub_queries"],
                "score": score,
            }
            for doc, score in scored_docs
        ]


class FaissAbstractSearch:
    """Context/abstract passage retriever over a LangChain-style vector store."""

    def __init__(self, db):
        # db: a LangChain-compatible vector store (e.g. FAISS).
        self.db = db

    def search(self, query: str, top_k: int = 10, **kwargs):
        """Return up to ``top_k`` passages as dicts with ``content`` and ``score``."""
        scored_docs = self.db.similarity_search_with_relevance_scores(query, top_k)
        return [
            {"content": doc.page_content, "score": score}
            for doc, score in scored_docs
        ]


class RetrievalAgent(Agent):
def __init__(
self,
knowledge_base,
few_shot_retriever: Optional[FaissFewShotSearch] = None,
context_retriever: Optional[FaissAbstractSearch] = None,
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
top_k: int = 2,
threshold: float = 0.1,
use_compressor: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.top_k = top_k
self.threshold = threshold

self.knowledge_base = knowledge_base
self.few_shot_retriever = few_shot_retriever
self.context_retriever = context_retriever
if self.few_shot_retriever and self.context_retriever:
raise Exception("Few shot retriever and context retriever shouldn't be used simutaneously")
if few_shot_retriever:
self.query_transform = PromptTemplate(
FEW_SHOT_QUERT_DECOMPOSITION, input_variables=["query", "documents"]
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
)
else:
self.query_transform = PromptTemplate(ZERO_SHOT_QUERY_DECOMPOSITION, input_variables=["query"])
self.rag_prompt = PromptTemplate(RAG_PROMPT, input_variables=["documents", "query"])
self.use_compressor = use_compressor
self.compressor = PromptTemplate(CONTENT_COMPRESSOR, input_variables=["context", "query"])
self.context_planning = PromptTemplate(CONTEXT_PLANNING, input_variables=["context", "query"])

async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
steps_taken: List[AgentStep] = []
if self.few_shot_retriever is not None:
# Get few shot examples
few_shots = self.few_shot_retriever.search(prompt, 3)
steps_input = HumanMessage(
content=self.query_transform.format(query=prompt, documents=few_shots)
)
steps_taken.append(RetrievalStep(name="few shot retriever", info=prompt, result=few_shots))
elif self.context_retriever:
res = self.context_retriever.search(prompt, 3)

context = [item["content"] for item in res]
steps_input = HumanMessage(
content=self.context_planning.format(query=prompt, context="\n".join(context))
)
steps_taken.append(RetrievalStep(name="context retriever", info=prompt, result=res))
else:
steps_input = HumanMessage(content=self.query_transform.format(query=prompt))
Comment on lines +113 to +114
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

什么时候会走进最后这个else branch呢?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zero shot,完全靠大模型自己的能力进行子query分解的时候

# Query planning
llm_resp = await self.run_llm(
messages=[steps_input],
)
output_message = llm_resp.message
json_results = self._parse_results(output_message.content)
sub_queries = json_results.values()
# Sub query execution
retrieval_results = await self.execute(sub_queries, steps_taken)

# Answer generation
step_input = HumanMessage(content=self.rag_prompt.format(query=prompt, documents=retrieval_results))
chat_history: List[Message] = [step_input]
llm_resp = await self.run_llm(
messages=chat_history,
)

output_message = llm_resp.message
chat_history.append(output_message)
self.memory.add_message(chat_history[0])
self.memory.add_message(chat_history[-1])
response = self._create_finished_response(chat_history, steps_taken)
return response

async def execute(self, sub_queries, steps_taken: List[AgentStep]):
retrieval_results = []
if self.use_compressor:
for idx, query in enumerate(sub_queries):
documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
docs = [item for item in documents["documents"]]
context = "\n".join([item["content"] for item in docs])
llm_resp = await self.run_llm(
messages=[HumanMessage(content=self.compressor.format(query=query, context=context))]
)
# Parse Compressed results
output_message = llm_resp.message
compressed_data = docs[0]
compressed_data["sub_query"] = query
compressed_data["content"] = output_message.content
retrieval_results.append(compressed_data)
steps_taken.append(
RetrievalStep(name=f"sub query compressor {idx}", info=query, result=compressed_data)
)
else:
duplicates = set()
for idx, query in enumerate(sub_queries):
documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
docs = [item for item in documents["documents"]]
steps_taken.append(
RetrievalStep(name=f"sub query results {idx}", info=query, result=documents)
)
for doc in docs:
if doc["content"] not in duplicates:
duplicates.add(doc["content"])
retrieval_results.append(doc)
retrieval_results = retrieval_results[:3]
return retrieval_results

def _parse_results(self, results):
left_index = results.find("{")
right_index = results.rfind("}")
return json.loads(results[left_index : right_index + 1])

def _create_finished_response(
self,
chat_history: List[Message],
steps: List[AgentStep],
) -> AgentResponse:
last_message = chat_history[-1]
return AgentResponse(
text=last_message.content,
chat_history=chat_history,
steps=steps,
status="FINISHED",
)
7 changes: 7 additions & 0 deletions erniebot-agent/src/erniebot_agent/agents/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ class AgentStep(Generic[_IT, _RT]):
result: _RT


@dataclass
class RetrievalStep(AgentStep):
    """An agent step recording a retrieval action.

    Extends ``AgentStep`` with a human-readable ``name`` identifying which
    retriever or sub-query produced the result.
    """

    # Label of the retrieval that produced this step (e.g. "few shot retriever").
    name: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果只是为了加一个name字段,就没有必要去新开一个class了。直接将name字段放入info就可以了,info是一个dict

Copy link
Collaborator Author

@w5688414 w5688414 Jan 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,但我有个疑问:是不是可以给每个step加一个别名,当成默认属性,然后单独拿出来?



@dataclass
class AgentStepWithFiles(AgentStep[_IT, _RT]):
"""A step taken by an agent involving file input and output."""
Expand Down
11 changes: 10 additions & 1 deletion erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ class BaizhongSearchTool(Tool):
ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

def __init__(
self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
self,
description,
db,
threshold: float = 0.0,
input_type=None,
output_type=None,
examples=None,
name=None,
) -> None:
super().__init__()
self.db = db
if name is not None:
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
self.name = name
self.description = description
self.few_shot_examples = []
if input_type is not None:
Expand Down