Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Agent][Feat] Ensure coroutine safety #282

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/modules/file.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ from erniebot_agent.file import GlobalFileManagerHandler
async def demo_function():
file_manager = GlobalFileManagerHandler().get()
# 通过fileid搜索文件
file = file_manager.look_up_file_by_id(file_id='your_file_id')
file = await file_manager.look_up_file_by_id(file_id='your_file_id')
# 读取file内容(bytes)
file_content = await file.read_contents()
# 写出到指定位置,your_willing_path需要具体到文件名
Expand Down
27 changes: 8 additions & 19 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Final,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Expand All @@ -32,7 +31,6 @@
from erniebot_agent.memory.messages import Message, SystemMessage
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager
from erniebot_agent.utils.exceptions import FileError

_PLUGINS_WO_FILE_IO: Final[Tuple[str]] = ("eChart",)

Expand Down Expand Up @@ -106,6 +104,7 @@ def __init__(
self._file_manager = file_manager or get_default_file_manager()
self._plugins = plugins
self._init_file_needs_url()
self._is_running = False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑agent被并发调用的情况:

def create_agent_run_task(prompt):
    return asyncio.create_task(agent.run(prompt))

create_agent_run_task(prompt1)
create_agent_run_task(prompt2)

在以上代码中,用户可能希望向agent派发任务,而这些任务将被并发执行。此处存在race condition:由于agent是有状态的(带有memory),两个任务都执行完成后、乃至执行过程中agent的状态将与两个任务的实际执行顺序与时机有关。

为了解决这个问题,我们不妨为Agent类引入一个属性_is_running,用这个属性来控制同一时刻agent只能执行一个任务。


@final
async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
Expand All @@ -119,8 +118,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
Returns:
Response from the agent.
"""
if files:
await self._ensure_managed_files(files)
if self._is_running:
raise RuntimeError("The agent is already running.")
self._is_running = True
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
agent_resp = await self._run(prompt, files)
Expand All @@ -129,6 +129,8 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
raise e
else:
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
finally:
self._is_running = False
return agent_resp

@final
Expand Down Expand Up @@ -247,10 +249,10 @@ async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
# XXX: Sniffing is less efficient and probably unnecessary.
# Can we make a protocol to statically recognize file inputs and outputs
# or can we have the tools introspect about this?
input_files = file_manager.sniff_and_extract_files_from_dict(parsed_tool_args)
input_files = await file_manager.sniff_and_extract_files_from_obj(parsed_tool_args)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处存在bug:sniff_and_extract_files_from_list不能递归地找到parsed_tool_args中可能存在的所有file,而只能侦测出位于top-level的file。将sniff_and_extract_files_from_list修改为sniff_and_extract_files_from_obj以解决这个问题。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块我在 #292 中已经解决了,合入之后你这块 update 一下就行了。

input_files 和 output_files 这块都可能需要调整一下。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯

tool_ret = await tool(**parsed_tool_args)
if isinstance(tool_ret, dict):
output_files = file_manager.sniff_and_extract_files_from_dict(tool_ret)
output_files = await file_manager.sniff_and_extract_files_from_obj(tool_ret)
else:
output_files = []
tool_ret_json = json.dumps(tool_ret, ensure_ascii=False)
Expand All @@ -275,16 +277,3 @@ def _parse_tool_args(self, tool_args: str) -> Dict[str, Any]:
if not isinstance(args_dict, dict):
raise ValueError(f"`tool_args` cannot be interpreted as a dict. `tool_args`: {tool_args}")
return args_dict

async def _ensure_managed_files(self, files: Sequence[File]) -> None:
def _raise_exception(file: File) -> NoReturn:
raise FileError(f"{repr(file)} is not managed by the file manager of the agent.")

file_manager = self.get_file_manager()
for file in files:
try:
managed_file = file_manager.look_up_file_by_id(file.id)
except FileError:
_raise_exception(file)
if file is not managed_file:
_raise_exception(file)
6 changes: 3 additions & 3 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ async def _step(
PluginStep(
info=output_message.plugin_info,
result=output_message.content,
input_files=file_manager.sniff_and_extract_files_from_text(
chat_history[-1].content
input_files=await file_manager.sniff_and_extract_files_from_text(
input_messages[-1].content
), # TODO: make sure this is correct.
output_files=file_manager.sniff_and_extract_files_from_text(output_message.content),
output_files=[],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plugin不具备处理file ID的功能,所以不应该试图从output_message中提取file ID。

),
new_messages,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def _upload(file, history):
history = history + [((single_file.name,), None)]
size = len(file)

output_lis = file_manager.list_registered_files()
output_lis = await file_manager.list_files()
item = ""
for i in range(len(output_lis) - size):
item += f'<li>{str(output_lis[i]).strip("<>")}</li>'
Expand Down
2 changes: 1 addition & 1 deletion erniebot-agent/src/erniebot_agent/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
>>> file_manager = GlobalFileManagerHandler().get()
>>> local_file = await file_manager.create_file_from_path(file_path='your_path', file_type='local')

>>> file = file_manager.look_up_file_by_id(file_id='your_file_id')
>>> file = await file_manager.look_up_file_by_id(file_id='your_file_id')
>>> file_content = await file.read_contents() # get file content(bytes)
>>> await local_file.write_contents_to('your_willing_path') # save to location you want
"""
Expand Down