[New Feature] Support streaming steps and messages from the stream agent (#345)
* stream agent: stream steps and messages back

* Clean up the code using make format-check

* Check the code with make format-check and format it with make format

* Check and clean up the code with `python -m mypy src`

* Check and clean up the code with `make lint`

* Make changes based on review feedback

* Check formatting with `python -m black --check`

* Use a typing.TYPE_CHECKING check to handle static type-checker-only code
xiabo0816 committed May 7, 2024
1 parent 621ce45 commit 3f2ecc0
Showing 3 changed files with 351 additions and 15 deletions.
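
Before diving into the diff, a minimal sketch of how the new streaming API might be consumed. Only the `async for step, messages in agent.run_stream(...)` shape comes from this commit; the imports, model name, and constructor arguments below are assumptions for illustration.

import asyncio

from erniebot_agent.agents.function_agent import FunctionAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot  # assumed model class/constructor
from erniebot_agent.memory import WholeMemory  # assumed import path


async def main():
    llm = ERNIEBot(model="ernie-3.5")  # hypothetical model name and arguments
    agent = FunctionAgent(llm=llm, tools=[], memory=WholeMemory())
    # run_stream yields (AgentStep, List[Message]) pairs as the run progresses.
    async for step, messages in agent.run_stream("What's the weather like today?"):
        print(type(step).__name__, [msg.content for msg in messages])


asyncio.run(main())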
128 changes: 126 additions & 2 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -2,7 +2,9 @@
import json
import logging
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Final,
Iterable,
@@ -20,7 +22,13 @@
from erniebot_agent.agents.callback.default import get_default_callbacks
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.mixins import GradioMixin
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
LLMResponse,
ToolResponse,
)
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import (
File,
@@ -131,13 +139,46 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
return agent_resp

@final
async def run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent asynchronously, returning an async iterator of responses.
Args:
prompt: A natural language text describing the task that the agent
should perform.
files: A list of files that the agent can use to perform the task.
Returns:
An async iterator of (agent step, new messages) tuples from the agent.
"""
if files:
await self._ensure_managed_files(files)
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
async for step, msg in self._run_stream(prompt, files):
yield (step, msg)
except BaseException as e:
await self._callback_manager.on_run_error(agent=self, error=e)
raise e
else:
await self._callback_manager.on_run_end(
agent=self,
response=AgentResponse(
text="Agent run stopped.",
chat_history=self.memory.get_messages(),
steps=[step],
status="STOPPED",
),
)

@final
async def run_llm(
self,
messages: List[Message],
**llm_opts: Any,
) -> LLMResponse:
"""Run the LLM asynchronously.
"""Run the LLM asynchronously, returning final response.
Args:
messages: The input messages.
@@ -156,6 +197,34 @@ async def run_llm(
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return llm_resp

@final
async def run_llm_stream(
self,
messages: List[Message],
**llm_opts: Any,
) -> AsyncIterator[LLMResponse]:
"""Run the LLM asynchronously, returning an async iterator of responses
Args:
messages: The input messages.
llm_opts: Options to pass to the LLM.
Returns:
Iterator of responses from the LLM.
"""
llm_resp = None
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
try:
# The LLM will return an async iterator.
async for llm_resp in self._run_llm_stream(messages, **(llm_opts or {})):
yield llm_resp
except (Exception, KeyboardInterrupt) as e:
await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
raise e
else:
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return

@final
async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
"""Run the specified tool asynchronously.
@@ -221,7 +290,32 @@ def get_file_manager(self) -> FileManager:
async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
raise NotImplementedError

@abc.abstractmethod
async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""
Abstract asynchronous generator method that should be implemented by subclasses.
This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given
prompt and optionally accompanying files.
"""
if TYPE_CHECKING:
# HACK
# This conditional block is strictly for static type-checking purposes (e.g., mypy)
# and will not be executed.
only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
yield only_for_mypy_type_check

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
"""Run the LLM with the given messages and options.
Args:
messages: The input messages.
opts: Options to pass to the LLM.
Returns:
Response from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")
@@ -241,6 +335,36 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
return LLMResponse(message=llm_ret)

async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
"""Run the LLM, yielding an async iterator of responses.
Args:
messages: The input messages.
opts: Options to pass to the LLM.
Returns:
Async iterator of responses from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")

if "functions" not in opts:
functions = self._tool_manager.get_tool_schemas()
else:
functions = opts.pop("functions")

if hasattr(self.llm, "system"):
_logger.warning(
"The `system` message has already been set in the agent;"
"the `system` message configured in ERNIEBot will become ineffective."
)
opts["system"] = self.system.content if self.system is not None else None
opts["plugins"] = self._plugins
llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
async for msg in llm_ret:
yield LLMResponse(message=msg)

async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
parsed_tool_args = self._parse_tool_args(tool_args)
file_manager = self.get_file_manager()
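
Note on the abstract `_run_stream` above: the `typing.TYPE_CHECKING` guarded `yield` exists only so static checkers such as mypy treat the method as an async generator with the right element type; it never executes. For orientation, a minimal subclass implementation might look like the sketch below — the `HumanMessage(content=...)` construction is an assumption, and the real FunctionAgent logic follows in the next file.

async def _run_stream(
    self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
    # A bare-bones streaming run: forward the prompt to the LLM and yield
    # each streamed chunk as a finish step carrying one new message.
    run_input = HumanMessage(content=prompt)  # file handling omitted for brevity
    async for llm_resp in self.run_llm_stream(messages=[run_input]):
        yield DEFAULT_FINISH_STEP, [llm_resp.message]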
140 changes: 127 additions & 13 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -13,7 +13,16 @@
# limitations under the License.

import logging
from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union
from typing import (
AsyncIterator,
Final,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
@@ -31,7 +40,12 @@
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import File, FileManager
from erniebot_agent.memory import Memory
from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message
from erniebot_agent.memory.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
Message,
)
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager

@@ -136,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
@@ -167,23 +181,122 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
response = self._create_stopped_response(chat_history, steps_taken)
return response

async def _step(
async def _call_first_tools(
self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
) -> Tuple[AgentStep, List[Message]]:
new_messages: List[Message] = []
input_messages = self.memory.get_messages() + chat_history
if selected_tool is not None:
tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()],  # only register one tool
tool_choice=tool_choice,
)
else:
if selected_tool is None:
llm_resp = await self.run_llm(messages=input_messages)
return await self._process_step(llm_resp, chat_history)

tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()],  # only register one tool
tool_choice=tool_choice,
)
return await self._process_step(llm_resp, chat_history)

async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
"""Run a step of the agent.
Args:
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
llm_resp = await self.run_llm(messages=input_messages)
return await self._process_step(llm_resp, chat_history)

async def _step_stream(
self, chat_history: List[Message]
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run a step of the agent in streaming mode.
Args:
chat_history: The chat history to provide to the agent.
Returns:
An async iterator that yields a tuple of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
async for llm_resp in self.run_llm_stream(messages=input_messages):
yield await self._process_step(llm_resp, chat_history)

async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent with the given prompt and files in streaming mode.
Args:
prompt: The prompt for the agent to run.
files: A list of files for the agent to use. If `None`, use an empty
list.
Returns:
An async iterator that yields (agent step, new messages) tuples one by one.
"""
chat_history: List[Message] = []
steps_taken: List[AgentStep] = []

run_input = await HumanMessage.create_with_files(
prompt, files or [], include_file_urls=self.file_needs_url
)

num_steps_taken = 0
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
steps_taken.append(curr_step)
else:
# If the tool choice did not work, skip this round
_logger.warning(f"Selected tool [{tool.tool_name}] did not work")

is_finished = False
new_messages = []
end_step_msgs = []
while is_finished is False:
# IMPORTANT! We use the following code to get the response from the LLM.
# When finish_reason is function_call, run_llm_stream returns all info in one step, but
# when finish_reason is normal chat, run_llm_stream returns info across multiple steps.
async for curr_step, new_messages in self._step_stream(chat_history):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)
yield curr_step, new_messages

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# Reserved for plugins that do not end the run after being called

# This handles plugins that end the run immediately after being called
curr_step = DEFAULT_FINISH_STEP
yield curr_step, new_messages

elif isinstance(curr_step, EndStep):
is_finished = True
end_step_msgs.extend(new_messages)
yield curr_step, new_messages
else:
raise RuntimeError("Invalid step type")
chat_history.extend(new_messages)

self.memory.add_message(run_input)
end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
self.memory.add_message(end_step_msg)

async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Process and execute a step of the agent from LLM response.
Args:
llm_resp: The LLM response to convert.
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
new_messages: List[Message] = []
output_message = llm_resp.message # AIMessage
new_messages.append(output_message)
# handle function call
if output_message.function_call is not None:
tool_name = output_message.function_call["name"]
tool_args = output_message.function_call["arguments"]
@@ -198,6 +311,7 @@ async def _step(
),
new_messages,
)
# handle plugin info with input/output files
elif output_message.plugin_info is not None:
file_manager = self.get_file_manager()
return (
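
The tail of `_run_stream` above joins the streamed end-step contents into a single `AIMessage` before writing it to memory. Consuming `run_llm_stream` directly follows the same accumulation pattern; a rough sketch, assuming an already-constructed agent and that `HumanMessage`/`AIMessage` accept a `content` argument as used in this diff:

from erniebot_agent.memory.messages import AIMessage, HumanMessage


async def collect_reply(agent, prompt: str) -> AIMessage:
    # Accumulate streamed chunks into one message, mirroring what
    # FunctionAgent._run_stream does with its end-step messages.
    chunks = []
    async for llm_resp in agent.run_llm_stream(messages=[HumanMessage(content=prompt)]):
        chunks.append(llm_resp.message.content or "")  # function-call chunks may carry no text
    return AIMessage(content="".join(chunks))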
