Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sparse search] Add sparse search agent #314

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 60 additions & 0 deletions erniebot-agent/applications/search_agent/retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import logging
from typing import Dict, List, Optional

import requests

_logger = logging.getLogger(__name__)


class CustomSearch:
    """Thin HTTP client for a custom sparse (keyword) search service.

    Authenticates either with a pre-fetched ``access_token`` or, when none is
    given, by exchanging the ``outId``/``key`` pair for a ticket at
    construction time.
    """

    def __init__(self, base_url, outId, key, access_token=None, timeout=60):
        """
        Args:
            base_url (str): Root URL of the search service.
            outId (str): Account identifier used to request a ticket.
            key (str): Secret key paired with ``outId``.
            access_token (Optional[str]): Pre-fetched token; when ``None``,
                a ticket is fetched immediately via ``_get_ticket``.
            timeout (Optional[float]): Per-request timeout in seconds passed
                to ``requests`` (``None`` disables it). Prevents calls from
                hanging indefinitely on an unresponsive server.
        """
        self._base_url = base_url
        self.outId = outId
        self.key = key
        self._timeout = timeout
        self.access_token = access_token
        if self.access_token is None:
            self.access_token = self._get_ticket()

    def _get_ticket(self):
        """Exchange ``outId``/``key`` for an access ticket.

        Returns:
            The ticket value found under the response's ``Data`` field.

        Raises:
            requests.HTTPError: If the service responds with an error status.
        """
        res = requests.post(
            f"{self._base_url}/api/account/getticket",
            # Let requests build and URL-encode the query string instead of
            # embedding the raw credentials in the URL.
            params={"outId": self.outId, "key": self.key},
            timeout=self._timeout,
        )
        # Fail loudly on auth errors rather than raising KeyError on "Data".
        res.raise_for_status()
        result = res.json()
        return result["Data"]

    def _get_authorization_headers(self, access_token: Optional[str]) -> Dict:
        """
        Initialize a dictionary for HTTP headers with Content-Type set to application/json.

        Args:
            access_token (str): The AIStudio access_token.

        Returns:
            Dict[str, Any]: A dictionary containing HTTP headers information.
        """
        headers = {"Content-Type": "application/json"}
        if access_token is None:
            _logger.warning("access_token is NOT provided, this may cause 403 HTTP error..")
        else:
            headers["Authorization"] = f"token {access_token}"
        return headers

    def search(self, searchKeywords: str, identifier: str = "U", top_k: int = 10, **kwargs) -> Dict:
        """Run a keyword search against the article search endpoint.

        Args:
            searchKeywords: Query string to search for.
            identifier: Service-specific search scope flag (default "U").
            top_k: Maximum number of results requested (``pageSize``).
            **kwargs: Extra query parameters forwarded to the endpoint.

        Returns:
            Dict: The decoded JSON response body.

        Raises:
            RuntimeError: If the service returns a non-200 status code.
        """
        data = {
            "pageSize": top_k,
            "searchKeywords": searchKeywords,
            "identifier": identifier,
        }
        data.update(kwargs)
        res = requests.post(
            f"{self._base_url}/api/search/getarticlesearchresult",
            headers=self._get_authorization_headers(access_token=self.access_token),
            params=data,
            timeout=self._timeout,
        )
        if res.status_code == 200:
            return res.json()
        raise RuntimeError(f"Error: {res.text}")
102 changes: 102 additions & 0 deletions erniebot-agent/applications/search_agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from collections import defaultdict
from typing import Dict, List, Optional

from retrieval import CustomSearch
from utils import JsonUtil

from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, Message
from erniebot_agent.prompt import PromptTemplate

# Prompt template for keyword/entity extraction: asks the LLM to pull
# concepts, years, quantities, etc. out of the input text and emit a JSON
# object of the form {"keywords": [...]}. Optionally renders few-shot
# examples from the ``documents`` variable.
KEYWORDS = """请提取文本中涉及的概念,年份,数量等实体,并将输出到json列表中.
json格式是:{"keywords": ["关键词1","关键词2"]}。
{% if documents|length > 0 %}
示例:
##
{% for doc in documents %}
文本:{{doc['query']}}
输出:{{doc['data']}}
{% endfor %}
##
{% endif %}
文本:{{query}}
输出:
"""


class SparseSearchAgent(JsonUtil):
    """Agent that expands a query with an LLM, runs sparse retrieval for the
    original and expanded queries, and fuses the result lists with
    reciprocal rank fusion (RRF)."""

    def __init__(
        self,
        llm: BaseERNIEBot,
        retrieval: CustomSearch,
        few_shots: Optional[List[Dict]] = None,
        join_mode: str = "reciprocal_rank_fusion",
        top_k: int = 10,
    ):
        """
        Args:
            llm: Chat model used for keyword extraction / query expansion.
            retrieval: Sparse search client used to fetch candidate documents.
            few_shots: Optional few-shot examples rendered into the keyword
                extraction prompt. Defaults to no examples.
            join_mode: Fusion strategy for combining result lists; only
                "reciprocal_rank_fusion" is supported.
            top_k: Number of fused documents to return.
        """
        self.llm = llm
        self.prompt = PromptTemplate(KEYWORDS, input_variables=["query", "documents"])
        self.retrieval = retrieval
        # Avoid the shared-mutable-default pitfall: each instance gets its own list.
        self.few_shots = [] if few_shots is None else few_shots
        self.join_mode = join_mode
        self.query_repeat = 5
        self.top_k_join = top_k

    async def run(self, query: str) -> List[Dict]:
        """Run the full expand -> retrieve -> fuse pipeline for ``query``."""
        agent_resp = await self._run(query)
        return agent_resp

    async def _run(self, query, top_k_join: Optional[int] = None):
        """Expand ``query``, retrieve for each variant, and fuse the results.

        Args:
            query: The user's original query.
            top_k_join: When truthy, overrides ``self.top_k_join``.

        Returns:
            List[Dict]: Top fused documents, each annotated with a "score".

        Raises:
            ValueError: If ``self.join_mode`` is not a supported fusion mode.
        """
        if top_k_join:
            self.top_k_join = top_k_join
        # TODO(wugaosheng): Add llm output to generate search keywords
        expanded_query = await self.query_expansion(query)
        # Add original query to avoid low-quality expanded keywords
        retrieval_results = []
        for candidate in [query, expanded_query]:
            results = self.retrieval.search(candidate)
            retrieval_results.append(results["Data"]["results"])

        # Ranking by fusion score
        if self.join_mode != "reciprocal_rank_fusion":
            # Previously an unsupported mode fell through to a NameError.
            raise ValueError(f"Unsupported join_mode: {self.join_mode!r}")
        scores_map = self._calculate_rrf(retrieval_results)
        sorted_docs = sorted(scores_map.items(), key=lambda d: d[1], reverse=True)
        # Map doc id -> doc across all result lists (later lists win on ties).
        document_map = {doc["id"]: doc for result in retrieval_results for doc in result}
        docs = []
        for doc_id, score in sorted_docs[: self.top_k_join]:
            doc = document_map[doc_id]
            doc["score"] = score
            docs.append(doc)
        return docs

    async def query_expansion(self, query):
        """Ask the LLM to extract keywords from ``query`` and build a new
        keyword-only query string from them."""
        content = self.prompt.format(query=query, documents=self.few_shots)
        messages: List[Message] = [HumanMessage(content)]
        response = await self.llm.chat(messages, response_format="json_object", enable_human_clarify=True)
        keyword_result = self.parse_json(response.content)
        # The model may also suggest a result count; honor it when present.
        if "pageSize" in keyword_result:
            self.top_k_join = int(keyword_result["pageSize"])
        return self.create_new_query(keyword_result["keywords"])

    def create_new_query(self, keywords):
        """Join the keywords with commas and repeat the string to boost
        sparse term weights in the retrieval backend."""
        new_query = ",".join(keywords)
        # Repeat query to increase term weights
        return new_query * self.query_repeat

    def _calculate_rrf(self, results):
        """
        Calculates the reciprocal rank fusion.
        The constant K is set to 61 (60 was suggested by the original paper,
        plus 1 as python lists are 0-based and the paper used 1-based ranking).
        """
        K = 61

        scores_map = defaultdict(int)
        for result in results:
            for rank, doc in enumerate(result):
                scores_map[doc["id"]] += 1 / (K + rank)

        return scores_map
15 changes: 15 additions & 0 deletions erniebot-agent/applications/search_agent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import json

from langchain.output_parsers.json import parse_json_markdown


class JsonUtil:
    """Mixin providing tolerant JSON extraction from raw LLM output text."""

    def parse_json(self, json_str, start_indicator: str = "{", end_indicator: str = "}"):
        """Parse a JSON value embedded somewhere inside *json_str*.

        With the default "{" indicator the string is handed to langchain's
        markdown-aware parser (handles ```json fences). For any other
        indicator pair (e.g. "[" / "]"), the outermost span between the first
        start indicator and the last end indicator is sliced out and parsed
        with the standard ``json`` module.
        """
        if start_indicator != "{":
            # Slice from the first opening indicator to the last closing one,
            # then parse the extracted span strictly.
            begin = json_str.index(start_indicator)
            finish = json_str.rindex(end_indicator)
            return json.loads(json_str[begin : finish + 1])
        # Default path: delegate to the markdown-aware JSON parser.
        return parse_json_markdown(json_str)