From d18cc16d7a6d7a7a1d33b03a5c7d1529437d5298 Mon Sep 17 00:00:00 2001
From: Sijun He
Date: Fri, 31 May 2024 16:12:27 +0800
Subject: [PATCH 1/3] add new models; add response_format

---
 .../src/erniebot/resources/chat_completion.py | 41 +++++++++++++++----
 erniebot/tests/test_chat_completion.py        | 20 ++++++++-
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py
index cd3fff0ce..6ae1fd550 100644
--- a/erniebot/src/erniebot/resources/chat_completion.py
+++ b/erniebot/src/erniebot/resources/chat_completion.py
@@ -58,15 +58,29 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
                 "ernie-3.5-8k": {
                     "model_id": "completions",
                 },
+                "ernie-3.5-8k-0205": {
+                    "model_id": "ernie-3.5-8k-0205",
+                },
+                "ernie-3.5-8k-0329": {
+                    "model_id": "ernie-3.5-8k-0329",
+                },
+                "ernie-3.5-128k": {
+                    "model_id": "ernie-3.5-128k",
+                },
                 "ernie-lite": {
                     "model_id": "eb-instant",
                 },
+                "ernie-lite-8k-0308": {
+                    "model_id": "ernie-lite-8k",
+                },
                 "ernie-4.0": {
                     "model_id": "completions_pro",
                 },
-                "ernie-longtext": {
-                    # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
-                    "model_id": "completions",
+                "ernie-4.0-8k-0329": {
+                    "model_id": "ernie-4.0-8k-0329",
+                },
+                "ernie-4.0-8k-0104": {
+                    "model_id": "ernie-4.0-8k-0104",
                 },
                 "ernie-speed": {
                     "model_id": "ernie_speed",
                 },
@@ -97,10 +111,6 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
                 "ernie-4.0": {
                     "model_id": "completions_pro",
                 },
-                "ernie-longtext": {
-                    # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
-                    "model_id": "completions",
-                },
                 "ernie-speed": {
                     "model_id": "ernie_speed",
                 },
@@ -156,6 +166,7 @@ def create(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> "ChatCompletionResponse":
@@ -183,6 +194,7 @@ def create(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> Iterator["ChatCompletionResponse"]:
@@ -210,6 +222,7 @@ def create(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -236,6 +249,7 @@ def create(
        extra_params: Optional[dict] = None,
        headers: Optional[HeadersType] = None,
        request_timeout: Optional[float] = None,
+        response_format: Optional[Literal["json_object", "text"]] = None,
        max_output_tokens: Optional[int] = None,
        _config_: Optional[ConfigDictType] = None,
    ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -292,6 +306,8 @@ def create(
            kwargs["headers"] = headers
        if request_timeout is not None:
            kwargs["request_timeout"] = request_timeout
+        if response_format is not None:
+            kwargs["response_format"] = response_format

        resp = resource.create_resource(**kwargs)
        return transform(ChatCompletionResponse.from_mapping, resp)
@@ -318,6 +334,7 @@ async def acreate(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> EBResponse:
@@ -345,6 +362,7 @@ async def acreate(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> AsyncIterator["ChatCompletionResponse"]:
@@ -372,6 +390,7 @@ async def acreate(
        extra_params: Optional[dict] = ...,
        headers: Optional[HeadersType] = ...,
        request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
        max_output_tokens: Optional[int] = ...,
        _config_: Optional[ConfigDictType] = ...,
    ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -398,6 +417,7 @@ async def acreate(
        extra_params: Optional[dict] = None,
        headers: Optional[HeadersType] = None,
        request_timeout: Optional[float] = None,
+        response_format: Optional[Literal["json_object", "text"]] = None,
        max_output_tokens: Optional[int] = None,
        _config_: Optional[ConfigDictType] = None,
    ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -423,6 +443,7 @@ async def acreate(
            validate_functions: Whether to validate the function descriptions.
            headers: Custom headers to send with the request.
            request_timeout: Timeout for a single request.
+            response_format: Format of the response.
            _config_: Overrides the global settings.

        Returns:
@@ -460,9 +481,9 @@ async def acreate(

    def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
        if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
-            for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
+            for arg in ("functions", "disable_search", "enable_citation", "tool_choice", "response_format"):
                if arg in kwargs:
-                    raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
+                    raise errors.InvalidArgumentError(f"`{arg}` is not supported by the `{model_name}` model.")

    def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
        def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
@@ -497,6 +518,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
            "extra_params",
            "headers",
            "request_timeout",
+            "response_format",
            "max_output_tokens",
        }

@@ -561,6 +583,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
        _set_val_if_key_exists(kwargs, params, "tool_choice")
        _set_val_if_key_exists(kwargs, params, "stream")
        _set_val_if_key_exists(kwargs, params, "max_output_tokens")
+        _set_val_if_key_exists(kwargs, params, "response_format")
        if "extra_params" in kwargs:
            params.update(kwargs["extra_params"])

diff --git a/erniebot/tests/test_chat_completion.py b/erniebot/tests/test_chat_completion.py
index d20d8eaba..ac3b2ade2 100644
--- a/erniebot/tests/test_chat_completion.py
+++ b/erniebot/tests/test_chat_completion.py
@@ -39,6 +39,18 @@ def create_chat_completion(model):
    print(response.get_result())


+def create_chat_completion_json_mode(model):
+    response = erniebot.ChatCompletion.create(
+        model=model,
+        messages=[
+            {"role": "user", "content": "文心一言是哪个公司开发的?"},
+        ],
+        stream=False,
+        response_format="json_object",
+    )
+    print(response.get_result())
+
+
 def create_chat_completion_stream(model):
    response = erniebot.ChatCompletion.create(
        model=model,
@@ -67,6 +79,10 @@ def create_chat_completion_stream(model):


 erniebot.api_type = "qianfan"
-create_chat_completion(model="ernie-turbo")
+# create_chat_completion(model="ernie-turbo")
+erniebot.ak = "gU71lRqGc8wmNHPZkqP9vToK"
+erniebot.sk = "l4P9sGVjonxhA8F3WQWZDrWx21G4GKQT"
+

-create_chat_completion_stream(model="ernie-turbo")
+# create_chat_completion_stream(model="ernie-turbo")
+create_chat_completion_json_mode(model="ernie-lite")

From 4819e38fd20233bbbc38e5f79eee2cbe13e16398 Mon Sep 17 00:00:00 2001
From: Sijun He
Date: Fri, 31 May 2024 16:15:39 +0800
Subject: [PATCH 2/3] remove aksk

---
 erniebot/tests/test_chat_completion.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/erniebot/tests/test_chat_completion.py b/erniebot/tests/test_chat_completion.py
index ac3b2ade2..073a44b44 100644
--- a/erniebot/tests/test_chat_completion.py
+++ b/erniebot/tests/test_chat_completion.py
@@ -79,10 +79,6 @@ def create_chat_completion_stream(model):


 erniebot.api_type = "qianfan"
-# create_chat_completion(model="ernie-turbo")
-erniebot.ak = "gU71lRqGc8wmNHPZkqP9vToK"
-erniebot.sk = "l4P9sGVjonxhA8F3WQWZDrWx21G4GKQT"
-
-
-# create_chat_completion_stream(model="ernie-turbo")
+create_chat_completion(model="ernie-turbo")
+create_chat_completion_stream(model="ernie-turbo")
 create_chat_completion_json_mode(model="ernie-lite")

From 7dbcd729fcb453db1bbf22e73528cb21eb9dd2fe Mon Sep 17 00:00:00 2001
From: Sijun He
Date: Fri, 31 May 2024 16:23:03 +0800
Subject: [PATCH 3/3] fix lint

---
 erniebot/src/erniebot/resources/chat_completion.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py
index 6ae1fd550..3f91c4bd3 100644
--- a/erniebot/src/erniebot/resources/chat_completion.py
+++ b/erniebot/src/erniebot/resources/chat_completion.py
@@ -483,7 +483,9 @@ def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
        if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
            for arg in ("functions", "disable_search", "enable_citation", "tool_choice", "response_format"):
                if arg in kwargs:
-                    raise errors.InvalidArgumentError(f"`{arg}` is not supported by the `{model_name}` model.")
+                    raise errors.InvalidArgumentError(
+                        f"`{arg}` is not supported by the `{model_name}` model."
+                    )

    def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
        def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
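
Usage sketch: the series registers the new ERNIE model aliases and threads a `response_format` argument through `ChatCompletion.create`/`acreate` down to the request parameters. A minimal call exercising the new argument against the Qianfan backend might look like the following; the model alias, the English prompt, and the assumption that credentials are configured out of band (instead of assigning `erniebot.ak`/`erniebot.sk` in code, as PATCH 1/3 briefly did and PATCH 2/3 reverted) are illustrative, not taken from the patches.

    import erniebot

    # Assumed setup: Qianfan credentials are supplied via erniebot's global
    # settings or environment configuration rather than hard-coded here.
    erniebot.api_type = "qianfan"

    response = erniebot.ChatCompletion.create(
        model="ernie-3.5-8k",  # any alias present in the patched model table
        messages=[{"role": "user", "content": "Which company develops ERNIE Bot?"}],
        response_format="json_object",  # new argument introduced by PATCH 1/3
    )
    print(response.get_result())

Note that PATCH 1/3 also adds "response_format" to the arguments rejected by `_check_model_kwargs` for the speed/lite/char/tiny models, so the JSON-object mode appears to be intended for the 3.5 and 4.0 aliases; the `create_chat_completion_json_mode(model="ernie-lite")` call in the test script would seem to trip that check as written.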