Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]Add response format #309

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 30 additions & 1 deletion erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import logging
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -44,6 +45,9 @@
_T = TypeVar("_T", AIMessage, AIMessageChunk)


_logger = logging.getLogger(__name__)


class BaseERNIEBot(ChatModel):
@overload
async def chat(
Expand Down Expand Up @@ -215,7 +219,15 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict
if functions is not None:
cfg_dict["functions"] = functions

name_list = ["top_p", "temperature", "penalty_score", "system", "plugins", "tool_choice"]
name_list = [
"top_p",
"temperature",
"penalty_score",
"system",
"plugins",
"tool_choice",
"response_format",
]
for name in name_list:
if name in kwargs:
cfg_dict[name] = kwargs[name]
Expand All @@ -227,6 +239,23 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict
# rm blank dict
if not cfg_dict["tool_choice"]:
cfg_dict.pop("tool_choice")

if "response_format" in cfg_dict:
if cfg_dict["response_format"] not in ("json_object", "text"):
if "json" in cfg_dict["response_format"]:
cfg_dict["response_format"] = "json_object"
_logger.warning(
f"`response_format` has invalid value:`{cfg_dict['response_format']}`, "
"use `json_object` instead. "
)
else:
# It will not raise error in request
_logger.warning(
f"`response_format` has invalid value:`{cfg_dict['response_format']}`, "
"use default value: `text`. "
"You can only choose `json_object` or `text`. "
)

return cfg_dict

def _maybe_validate_qianfan_auth(self) -> None:
Expand Down
12 changes: 12 additions & 0 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> "ChatCompletionResponse":
...
Expand All @@ -141,6 +142,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Iterator["ChatCompletionResponse"]:
...
Expand All @@ -167,6 +169,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
...
Expand All @@ -192,6 +195,7 @@ def create(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -238,6 +242,7 @@ def create(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
response_format=response_format,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand Down Expand Up @@ -271,6 +276,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请保持和其他参数一致,使用 "..."

_config_: Optional[ConfigDictType] = ...,
) -> EBResponse:
...
Expand All @@ -297,6 +303,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请保持和其他参数一致,使用 "..."

_config_: Optional[ConfigDictType] = ...,
) -> AsyncIterator["ChatCompletionResponse"]:
...
Expand All @@ -323,6 +330,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请保持和其他参数一致,使用 "..."

_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
...
Expand All @@ -348,6 +356,7 @@ async def acreate(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -394,6 +403,7 @@ async def acreate(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
response_format=response_format,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand Down Expand Up @@ -438,6 +448,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"extra_params",
"headers",
"request_timeout",
"response_format",
}

invalid_keys = kwargs.keys() - valid_keys
Expand Down Expand Up @@ -500,6 +511,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
_set_val_if_key_exists(kwargs, params, "user_id")
_set_val_if_key_exists(kwargs, params, "tool_choice")
_set_val_if_key_exists(kwargs, params, "stream")
_set_val_if_key_exists(kwargs, params, "response_format")
if "extra_params" in kwargs:
params.update(kwargs["extra_params"])

Expand Down