Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#11 from qingzhong1/eb7
Browse files Browse the repository at this point in the history
modify log
  • Loading branch information
w5688414 committed Jan 3, 2024
2 parents a14bae9 + 8484630 commit 9063a7e
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from typing import Optional

from tools.utils import erniebot_chat, json_correct, write_to_json
from tools.utils import ReportCallbackHandler, erniebot_chat, json_correct

from erniebot_agent.agents.agent import Agent
from erniebot_agent.prompt import PromptTemplate
Expand Down Expand Up @@ -54,27 +54,33 @@ def __init__(
name: str,
llm: str = "ernie-4.0",
system_message: Optional[str] = None,
config: list = [],
save_log_path=None,
callbacks=None,
):
self.name = name
self.system_message = system_message or self.DEFAULT_SYSTEM_MESSAGE
self.model = llm
self.config = config
self.save_log_path = save_log_path
self.prompt = PromptTemplate(" 草稿为:\n\n{{report}}", input_variables=["report"])
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
self._callback_manager = callbacks

async def _run(self, report):
await self._callback_manager.on_run_start(agent_name=self.name, query="")
messages = [
{
"role": "user",
"content": self.prompt.format(report=report),
}
]
if len(messages[0]["content"]) > 4800:
model = "ernie-longtext"
else:
model = "ernie-4.0"
while True:
try:
suggestions = erniebot_chat(
messages=messages, functions=eb_functions, model=self.model, system=self.system_message
messages=messages, functions=eb_functions, model=model, system=self.system_message
)
start_idx = suggestions.index("{")
end_idx = suggestions.rindex("}")
Expand All @@ -83,14 +89,9 @@ async def _run(self, report):
suggestions = json.loads(suggestions)
if "accept" not in suggestions and "notes" not in suggestions:
raise Exception("accept and notes key do not exist")

self.config.append(("编辑给出的建议", f"{suggestions}\n\n"))
self.save_log()
await self._callback_manager.on_run_end(self.name, suggestions)
return suggestions
except Exception as e:
logger.error(e)
self.config.append(("报错信息", e))
await self._callback_manager.on_run_error(self.name, error_information=str(e))
continue

def save_log(self):
write_to_json(self.save_log_path, self.config, mode="a")
24 changes: 12 additions & 12 deletions erniebot-agent/applications/erniebot_researcher/ranking_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import logging
from typing import Optional

from erniebot.prompt import PromptTemplate
from tools.utils import erniebot_chat, write_to_json
from tools.utils import ReportCallbackHandler, erniebot_chat

from erniebot_agent.agents.agent import Agent
from erniebot_agent.prompt import PromptTemplate

logger = logging.getLogger(__name__)

Expand All @@ -30,38 +30,37 @@ def __init__(
name: str,
ranking_tool,
system_message: Optional[str] = None,
config: list = [],
save_log_path=None,
callbacks=None,
is_reset=False,
) -> None:
self.name = name
self.system_message = system_message or self.DEFAULT_SYSTEM_MESSAGE

self.ranking = ranking_tool
self.config = config
self.save_log_path = save_log_path
self.is_reset = False
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
self._callback_manager = callbacks

async def _run(self, list_reports, query):
self._callback_manager.on_run_start(self.name, "")
reports = []
for item in list_reports:
if self.check_format(item):
reports.append(item)
if len(reports) == 0:
if self.is_reset:
self._callback_manager.on_run_end(self.name, "所有的report都不是markdown格式,重新生成report")
logger.info("所有的report都不是markdown格式,重新生成report")
return [], None
else:
reports = list_reports
best_report = await self.ranking(reports, query)
self.config.append(("最好的report", best_report))
if self.save_log_path:
self.save_log()
self._callback_manager.on_run_tool(self.ranking.description, best_report)
self._callback_manager.on_run_end(self.name, "")
return reports, best_report

def save_log(self):
write_to_json(self.save_log_path, self.config, mode="a")

def check_format(self, report):
while True:
try:
Expand All @@ -76,5 +75,6 @@ def check_format(self, report):
elif result_dict["accept"] is False or result_dict["accept"] == "false":
return False
except Exception as e:
self._callback_manager.on_run_error("格式检查", str(e))
logger.error(e)
continue
62 changes: 29 additions & 33 deletions erniebot-agent/applications/erniebot_researcher/research_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from collections import OrderedDict
from typing import Optional

from tools.utils import add_citation, erniebot_chat, write_to_json
from tools.utils import ReportCallbackHandler, add_citation, erniebot_chat

from erniebot_agent.agents.agent import Agent
from erniebot_agent.prompt import PromptTemplate

logger = logging.getLogger(__name__)

SUMMARIZE_MAX_LENGTH = 1800

SELECT_PROMPT = """
Expand All @@ -17,7 +17,7 @@
"""


class ResearchAgent:
class ResearchAgent(Agent):
"""
ResearchAgent, refer to
https://github.com/assafelovic/gpt-researcher/blob/master/examples/permchain_agents/research_team.py
Expand All @@ -40,13 +40,12 @@ def __init__(
citation_tool,
summarize_tool,
faiss_name_citation,
config=[],
system_message: Optional[str] = None,
use_outline=True,
use_context_planning=True,
save_log_path=None,
nums_queries=4,
embeddings=None,
callbacks=None,
):
"""
Initialize the ResearchAgent class.
Expand All @@ -56,10 +55,9 @@ def __init__(
......
"""
self.name = name
self.system_message = system_message or self.DEFAULT_SYSTEM_MESSAGE
self.system_message = system_message or self.DEFAULT_SYSTEM_MESSAGE # type: ignore
self.dir_path = dir_path
self.report_type = report_type
self.cfg = config
self.retriever = retriever_tool
self.retriever_abstract = retriever_abstract_tool
self.intent_detection = intent_detection_tool
Expand All @@ -72,12 +70,14 @@ def __init__(
self.use_outline = use_outline
self.agent_name = agent_name
self.faiss_name_citation = faiss_name_citation
self.config = config
self.save_log_path = save_log_path
self.use_context_planning = use_context_planning
self.nums_queries = nums_queries
self.select_prompt = PromptTemplate(SELECT_PROMPT, input_variables=["queries", "question"])
self.embeddings = embeddings
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
self._callback_manager = callbacks

async def run_search_summary(self, query):
responses = []
Expand All @@ -95,24 +95,25 @@ async def run_search_summary(self, query):
value = doc["url"]
url_dict[key] = value
else:
logger.warning(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
print(f"summary size exceed {SUMMARIZE_MAX_LENGTH}")
break
return responses, url_dict

async def run(self, query):
async def _run(self, query):
"""
Runs the ResearchAgent
Returns:
Report
"""
logger.info(f"🔎 Running research for '{query}'...")
self.config.append(("开始", f"🔎 Running research for '{query}'..."))
self.save_log()
await self._callback_manager.on_run_start(
agent_name=self.name, query=f"🔎 Running research for '{query}'..."
)
# Generate Agent
result = await self.intent_detection(query)
self.agent, self.role = result["agent"], result["agent_role_prompt"]
self.config.append((None, self.agent + self.role))
self.save_log()
await self._callback_manager.on_run_tool(
tool_name=self.intent_detection.description, response=self.agent + self.role
)
if self.use_context_planning:
sub_queries = []
res = self.retriever_abstract.search(query, top_k=3)
Expand Down Expand Up @@ -147,19 +148,17 @@ async def run(self, query):
sub_queries = await self.task_planning(
question=query, agent_role_prompt=self.role, context=context
)
self.config.append(("任务分解", "\n".join(sub_queries)))
self.save_log()
await self._callback_manager.on_run_tool(
tool_name=self.task_planning.description, response="\n".join(sub_queries)
)
# Run Sub-Queries
meta_data = OrderedDict()
# research_summary = ""
paragraphs_item = []
# summary_list=[]
for sub_query in sub_queries:
research_result, url_dict = await self.run_search_summary(sub_query)
meta_data.update(url_dict)
paragraphs_item.extend(research_result)
self.config.append((sub_query, f"{research_result}\n\n"))
self.save_log()
await self._callback_manager.on_run_tool(tool_name=sub_query, response=f"{research_result}\n\n")
paragraphs = []
for item in paragraphs_item:
if item not in paragraphs:
Expand All @@ -169,8 +168,7 @@ async def run(self, query):
# Generate Outline
if self.use_outline:
outline = await self.outline(sub_queries, query)
self.config.append(("报告大纲", outline))
self.save_log()
await self._callback_manager.on_run_tool(tool_name=self.outline.description, response=outline)
else:
outline = None
# Conduct Research
Expand All @@ -186,19 +184,17 @@ async def run(self, query):
)
break
except Exception as e:
logger.error(e)
self.config.append(("报错", str(e)))
await self._callback_manager.on_run_error(
tool_name=self.report_writing.description, error_information=str(e)
)
continue
self.config.append(("草稿", report))
self.save_log()
await self._callback_manager.on_run_tool(tool_name=self.report_writing.description, response=report)
# Generate Citations
citation_search = add_citation(paragraphs, self.faiss_name_citation, self.embeddings)
final_report, path = await self.citation(
report, url_index, self.agent_name, self.report_type, self.dir_path, citation_search
)
self.config.append(("草稿加引用", report))
self.save_log()
await self._callback_manager.on_run_tool(tool_name=self.citation.description, response=final_report)
await self._callback_manager.on_run_end(tool_name=self.name, response=f"报告存储在{path}")
breakpoint()
return final_report, path

def save_log(self):
write_to_json(self.save_log_path, self.config)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Optional

from tools.utils import erniebot_chat, write_to_json
from tools.utils import ReportCallbackHandler, erniebot_chat

from erniebot_agent.agents.agent import Agent
from erniebot_agent.prompt.prompt_template import PromptTemplate
Expand All @@ -19,34 +19,37 @@ def __init__(
name: str,
llm: str = "erine-4.0",
system_message: Optional[str] = None,
config: list = [],
save_log_path=None,
callbacks=None,
):
self.name = name
self.system_message = system_message or self.DEFAULT_SYSTEM_MESSAGE
self.model = llm
self.template = "草稿:\n\n{{draft}}" + "编辑的备注:\n\n{{notes}}"
self.prompt_template = PromptTemplate(template=self.template, input_variables=["draft", "notes"])
self.config = config
self.save_log_path = save_log_path
if callbacks is None:
self._callback_manager = ReportCallbackHandler()
else:
self._callback_manager = callbacks

async def _run(self, draft, notes):
self._callback_manager.on_run_start(self.name, "")
messages = [
{
"role": "user",
"content": self.prompt_template.format(draft=draft, notes=notes).replace(". ", "."),
}
]
if len(messages[0]["content"]) > 4800:
model = "ernie-longtext"
else:
model = "ernie-4.0"
while True:
try:
report = erniebot_chat(messages=messages, system=self.system_message)
report = erniebot_chat(messages=messages, system=self.system_message, model=model)
self.config.append(("修订后的报告", report))
self.save_log()
self._callback_manager.on_run_end(self.name, report)
return report
except Exception as e:
logger.error(e)
self.config.append(("报错信息", e))
self._callback_manager.on_run_error(self.name, str(e))
continue

def save_log(self):
write_to_json(self.save_log_path, self.config, mode="a")
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,5 @@ async def __call__(
if meta_data:
for index, (key, val) in enumerate(meta_data.items()):
url_index[val] = {"name": key, "index": index + 1}
# final_report=postprocess(final_report)
return final_report, url_index

0 comments on commit 9063a7e

Please sign in to comment.