Skip to content

Commit

Permalink
[API ADD]Add max_out_token (#343)
Browse files Browse the repository at this point in the history
* add `max_out_token`

* fix flake
  • Loading branch information
Southpika committed Apr 29, 2024
1 parent 8de8fdb commit 621ce45
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> "ChatCompletionResponse":
...
Expand All @@ -182,6 +183,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Iterator["ChatCompletionResponse"]:
...
Expand All @@ -208,6 +210,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
...
Expand All @@ -233,6 +236,7 @@ def create(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
max_output_tokens: Optional[int] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -279,6 +283,7 @@ def create(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
max_output_tokens=max_output_tokens,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand Down Expand Up @@ -313,6 +318,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> EBResponse:
...
Expand All @@ -339,6 +345,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> AsyncIterator["ChatCompletionResponse"]:
...
Expand All @@ -365,6 +372,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
...
Expand All @@ -390,6 +398,7 @@ async def acreate(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
max_output_tokens: Optional[int] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -436,6 +445,7 @@ async def acreate(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
max_output_tokens=max_output_tokens,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand All @@ -450,12 +460,7 @@ async def acreate(

def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
for arg in (
"functions",
"disable_search",
"enable_citation",
"tool_choice",
):
for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")

Expand Down Expand Up @@ -492,6 +497,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"extra_params",
"headers",
"request_timeout",
"max_output_tokens",
}

invalid_keys = kwargs.keys() - valid_keys
Expand Down Expand Up @@ -554,6 +560,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
_set_val_if_key_exists(kwargs, params, "user_id")
_set_val_if_key_exists(kwargs, params, "tool_choice")
_set_val_if_key_exists(kwargs, params, "stream")
_set_val_if_key_exists(kwargs, params, "max_output_tokens")

if "extra_params" in kwargs:
params.update(kwargs["extra_params"])

Expand Down

0 comments on commit 621ce45

Please sign in to comment.