From 0486d4f972a092f1d84ace56fd8ab957124869e6 Mon Sep 17 00:00:00 2001 From: Brian Krabach Date: Wed, 6 Nov 2024 15:50:24 +0000 Subject: [PATCH] renames helper functions in openai_client --- .../chat-driver/chat_driver/chat_driver.py | 14 ++++----- .../in_memory_message_history_provider.py | 16 +++++----- .../local_message_history_provider.py | 16 +++++----- .../tests/formatted_instructions_test.py | 4 +-- .../openai-client/.vscode/settings.json | 3 +- .../openai-client/openai_client/completion.py | 8 ++--- .../openai-client/openai_client/errors.py | 10 +++---- .../openai-client/openai_client/logging.py | 6 ++-- .../openai-client/openai_client/messages.py | 26 ++++++++-------- .../openai-client/openai_client/tools.py | 30 +++++++++++-------- 10 files changed, 69 insertions(+), 64 deletions(-) diff --git a/libraries/python/chat-driver/chat_driver/chat_driver.py b/libraries/python/chat-driver/chat_driver/chat_driver.py index 6f21f155..0a0044ba 100644 --- a/libraries/python/chat-driver/chat_driver/chat_driver.py +++ b/libraries/python/chat-driver/chat_driver/chat_driver.py @@ -11,10 +11,10 @@ ChatCompletionUserMessageParam, ) from openai.types.chat.completion_create_params import ResponseFormat -from openai_client.completion import TEXT_RESPONSE_FORMAT, completion_message_string +from openai_client.completion import TEXT_RESPONSE_FORMAT, message_string_from_completion from openai_client.errors import CompletionError -from openai_client.messages import MessageFormatter, format_message -from openai_client.tools import complete_with_tool_calls, function_registry_to_tools, tool_choice +from openai_client.messages import MessageFormatter, format_with_dict +from openai_client.tools import complete_with_tool_calls, function_list_to_tools, function_registry_to_tools from pydantic import BaseModel from .local_message_history_provider import ( @@ -72,7 +72,7 @@ def __init__(self, config: ChatDriverConfig) -> None: self.instructions: list[str] = ( config.instructions if isinstance(config.instructions, list) else [config.instructions] ) - self.instruction_formatter = config.instruction_formatter or format_message + self.instruction_formatter = config.instruction_formatter or format_with_dict # Now set up the OpenAI client and model. self.client = config.openai_client @@ -178,7 +178,7 @@ async def respond( "model": self.model, "messages": [*self._formatted_instructions(instruction_parameters), *(await self.message_provider.get())], "tools": function_registry_to_tools(self.function_registry), - "tool_choice": tool_choice(function_choice), + "tool_choice": function_list_to_tools(function_choice), "response_format": response_format, } try: @@ -198,7 +198,7 @@ async def respond( # Return the response. return MessageEvent( - message=completion_message_string(completion) or None, + message=message_string_from_completion(completion) or None, metadata=metadata, ) @@ -216,7 +216,7 @@ def format_instructions( with the variables. This method returns a list of system messages formatted with the variables. """ - formatter = formatter or format_message + formatter = formatter or format_with_dict instruction_messages: list[ChatCompletionSystemMessageParam] = [] for instruction in instructions: if vars: diff --git a/libraries/python/chat-driver/chat_driver/in_memory_message_history_provider.py b/libraries/python/chat-driver/chat_driver/in_memory_message_history_provider.py index 9314a383..c8a09106 100644 --- a/libraries/python/chat-driver/chat_driver/in_memory_message_history_provider.py +++ b/libraries/python/chat-driver/chat_driver/in_memory_message_history_provider.py @@ -4,10 +4,10 @@ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam from openai_client.messages import ( MessageFormatter, - assistant_message, - format_message, - system_message, - user_message, + create_assistant_message, + create_system_message, + create_user_message, + format_with_dict, ) @@ -17,7 +17,7 @@ def __init__( messages: list[ChatCompletionMessageParam] | None = None, formatter: MessageFormatter | None = None, ) -> None: - self.formatter: MessageFormatter = formatter or format_message + self.formatter: MessageFormatter = formatter or format_with_dict self.messages = messages or [] async def get(self) -> list[ChatCompletionMessageParam]: @@ -40,10 +40,10 @@ def delete_all(self) -> None: self.messages = [] def append_system_message(self, content: str, var: dict[str, Any] | None = None) -> None: - asyncio.run(self.append(system_message(content, var, self.formatter))) + asyncio.run(self.append(create_system_message(content, var, self.formatter))) def append_user_message(self, content: str, var: dict[str, Any] | None = None) -> None: - asyncio.run(self.append(user_message(content, var, self.formatter))) + asyncio.run(self.append(create_user_message(content, var, self.formatter))) def append_assistant_message( self, @@ -52,4 +52,4 @@ def append_assistant_message( tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None, var: dict[str, Any] | None = None, ) -> None: - asyncio.run(self.append(assistant_message(content, refusal, tool_calls, var, self.formatter))) + asyncio.run(self.append(create_assistant_message(content, refusal, tool_calls, var, self.formatter))) diff --git a/libraries/python/chat-driver/chat_driver/local_message_history_provider.py b/libraries/python/chat-driver/chat_driver/local_message_history_provider.py index 53cd6ec5..ffa82045 100644 --- a/libraries/python/chat-driver/chat_driver/local_message_history_provider.py +++ b/libraries/python/chat-driver/chat_driver/local_message_history_provider.py @@ -11,10 +11,10 @@ ) from openai_client.messages import ( MessageFormatter, - assistant_message, - format_message, - system_message, - user_message, + create_assistant_message, + create_system_message, + create_user_message, + format_with_dict, ) from .message_history_provider import MessageHistoryProviderProtocol @@ -36,7 +36,7 @@ def __init__(self, config: LocalMessageHistoryProviderConfig) -> None: self.data_dir = DEFAULT_DATA_DIR / "chat_driver" / config.context.session_id else: self.data_dir = Path(config.data_dir) - self.formatter: MessageFormatter = config.formatter or format_message + self.formatter: MessageFormatter = config.formatter or format_with_dict # Create the messages file if it doesn't exist. if not self.data_dir.exists(): @@ -79,10 +79,10 @@ def delete_all(self) -> None: self.messages_file.write_text("[]") async def append_system_message(self, content: str, var: dict[str, Any] | None = None) -> None: - await self.append(system_message(content, var, self.formatter)) + await self.append(create_system_message(content, var, self.formatter)) async def append_user_message(self, content: str, var: dict[str, Any] | None = None) -> None: - await self.append(user_message(content, var, self.formatter)) + await self.append(create_user_message(content, var, self.formatter)) async def append_assistant_message( self, @@ -91,4 +91,4 @@ async def append_assistant_message( tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None, var: dict[str, Any] | None = None, ) -> None: - await self.append(assistant_message(content, refusal, tool_calls, var, self.formatter)) + await self.append(create_assistant_message(content, refusal, tool_calls, var, self.formatter)) diff --git a/libraries/python/chat-driver/tests/formatted_instructions_test.py b/libraries/python/chat-driver/tests/formatted_instructions_test.py index a8331315..c79ef24c 100644 --- a/libraries/python/chat-driver/tests/formatted_instructions_test.py +++ b/libraries/python/chat-driver/tests/formatted_instructions_test.py @@ -1,5 +1,5 @@ from chat_driver.chat_driver import ChatDriver -from openai_client.messages import format_message +from openai_client.messages import format_with_dict def test_formatted_instructions() -> None: @@ -34,7 +34,7 @@ def test_formatted_instructions() -> None: "user_feedback": user_feedback, "chat_history": chat_history, }, - formatter=format_message, + formatter=format_with_dict, ) expected = [ diff --git a/libraries/python/openai-client/.vscode/settings.json b/libraries/python/openai-client/.vscode/settings.json index d393f820..77378705 100644 --- a/libraries/python/openai-client/.vscode/settings.json +++ b/libraries/python/openai-client/.vscode/settings.json @@ -38,6 +38,7 @@ }, "cSpell.words": [ "openai", - "Pydantic" + "Pydantic", + "tiktoken" ] } diff --git a/libraries/python/openai-client/openai_client/completion.py b/libraries/python/openai-client/openai_client/completion.py index b42051e4..dbbc3824 100644 --- a/libraries/python/openai-client/openai_client/completion.py +++ b/libraries/python/openai-client/openai_client/completion.py @@ -30,18 +30,18 @@ def assistant_message_from_completion(completion: ParsedChatCompletion[None]) -> return assistant_message -def completion_message(completion: ParsedChatCompletion) -> ParsedChatCompletionMessage | None: +def message_from_completion(completion: ParsedChatCompletion) -> ParsedChatCompletionMessage | None: return completion.choices[0].message if completion and completion.choices else None -def completion_message_string(completion: ParsedChatCompletion | None) -> str: +def message_string_from_completion(completion: ParsedChatCompletion | None) -> str: if not completion or not completion.choices or not completion.choices[0].message: return "" return completion.choices[0].message.content or "" -def completion_message_dict(completion: ParsedChatCompletion) -> dict[str, Any] | None: - message = completion_message(completion) +def message_dict_from_completion(completion: ParsedChatCompletion) -> dict[str, Any] | None: + message = message_from_completion(completion) if message: if message.parsed: if isinstance(message.parsed, BaseModel): diff --git a/libraries/python/openai-client/openai_client/errors.py b/libraries/python/openai-client/openai_client/errors.py index d5070768..a56260c8 100644 --- a/libraries/python/openai-client/openai_client/errors.py +++ b/libraries/python/openai-client/openai_client/errors.py @@ -11,24 +11,24 @@ ) -class InValidCompletionError(Exception): +class CompletionInvalidError(Exception): def __init__(self, message: str, body: dict[str, Any] | None = None) -> None: self.message = message self.body = body super().__init__(self.message) -class CompletionIsNoneError(InValidCompletionError): +class CompletionIsNoneError(CompletionInvalidError): def __init__(self) -> None: super().__init__("The completion response is None.") -class CompletionRefusedError(InValidCompletionError): +class CompletionRefusedError(CompletionInvalidError): def __init__(self, refusal: str) -> None: super().__init__(f"The model refused to complete the response: {refusal}", {"refusal": refusal}) -class CompletionWithoutStopError(InValidCompletionError): +class CompletionWithoutStopError(CompletionInvalidError): def __init__(self, finish_reason: str) -> None: super().__init__(f"The model did not complete the response: {finish_reason}", {"finish_reason": finish_reason}) @@ -47,7 +47,7 @@ def __init__(self, error: Exception) -> None: elif isinstance(error, APIStatusError): message = f"Another non-200-range status code was received. {error.status_code}: {error.message}" body = error.body - elif isinstance(error, InValidCompletionError): + elif isinstance(error, CompletionInvalidError): message = error.message body = error.body else: diff --git a/libraries/python/openai-client/openai_client/logging.py b/libraries/python/openai-client/openai_client/logging.py index fac8705b..00b88a69 100644 --- a/libraries/python/openai-client/openai_client/logging.py +++ b/libraries/python/openai-client/openai_client/logging.py @@ -8,7 +8,7 @@ from pydantic import BaseModel -def serializable_completion_args(completion_args: dict[str, Any]) -> dict[str, Any]: +def make_completion_args_serializable(completion_args: dict[str, Any]) -> dict[str, Any]: """ We put the completion args into logs and messages, so it's important that they are serializable. This function returns a copy of the completion args @@ -30,9 +30,9 @@ def serializable_completion_args(completion_args: dict[str, Any]) -> dict[str, A return sanitized -def extra_data(data: Any) -> dict[str, Any]: +def add_serializable_data(data: Any) -> dict[str, Any]: """ - Helper function to add extra data to log messages. + Helper function to use when adding extra data to log messages. """ extra = {} diff --git a/libraries/python/openai-client/openai_client/messages.py b/libraries/python/openai-client/openai_client/messages.py index 2f3e58ff..01a77ba0 100644 --- a/libraries/python/openai-client/openai_client/messages.py +++ b/libraries/python/openai-client/openai_client/messages.py @@ -41,7 +41,7 @@ def truncate_messages_for_logging( results.append(message) case list(): - compressed = process_list(content, maximum_content_length, filler_text) + compressed = apply_truncation_to_list(content, maximum_content_length, filler_text) message["content"] = compressed # type: ignore results.append(message) @@ -57,7 +57,7 @@ def truncate_string(string: str, maximum_length: int, filler_text: str) -> str: return string[:head_tail_length] + filler_text + string[-head_tail_length:] -def process_list(list_: list, maximum_length: int, filler_text: str) -> list: +def apply_truncation_to_list(list_: list, maximum_length: int, filler_text: str) -> list: for part in list_: for key, value in part.items(): match value: @@ -65,25 +65,25 @@ def process_list(list_: list, maximum_length: int, filler_text: str) -> list: part[key] = truncate_string(value, maximum_length, filler_text) case dict(): - part[key] = process_dict(value, maximum_length, filler_text) + part[key] = apply_truncation_to_dict(value, maximum_length, filler_text) return list_ -def process_dict(dict_: dict, maximum_length: int, filler_text: str) -> dict: +def apply_truncation_to_dict(dict_: dict, maximum_length: int, filler_text: str) -> dict: for key, value in dict_.items(): match value: case str(): dict_[key] = truncate_string(value, maximum_length, filler_text) case dict(): - dict_[key] = process_dict(value, maximum_length, filler_text) + dict_[key] = apply_truncation_to_dict(value, maximum_length, filler_text) return dict_ MessageFormatter = Callable[[str, dict[str, Any]], str] -def format_message(message: str, vars: dict[str, Any]) -> str: +def format_with_dict(message: str, vars: dict[str, Any]) -> str: """ Format a message with the given variables using the Python format method. """ @@ -96,7 +96,7 @@ def format_message(message: str, vars: dict[str, Any]) -> str: return message -def liquid_format(message: str, vars: dict[str, Any]) -> str: +def format_with_liquid(message: str, vars: dict[str, Any]) -> str: """ Format a message with the given variables using the Liquid template engine. """ @@ -108,28 +108,28 @@ def liquid_format(message: str, vars: dict[str, Any]) -> str: return out -def system_message( - content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_message +def create_system_message( + content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_with_dict ) -> ChatCompletionSystemMessageParam: if var: content = formatter(content, var) return {"role": "system", "content": content} -def user_message( - content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_message +def create_user_message( + content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_with_dict ) -> ChatCompletionUserMessageParam: if var: content = formatter(content, var) return {"role": "user", "content": content} -def assistant_message( +def create_assistant_message( content: str, refusal: Optional[str] = None, tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None, var: dict[str, Any] | None = None, - formatter: MessageFormatter = format_message, + formatter: MessageFormatter = format_with_dict, ) -> ChatCompletionAssistantMessageParam: if var: content = formatter(content, var) diff --git a/libraries/python/openai-client/openai_client/tools.py b/libraries/python/openai-client/openai_client/tools.py index 05a06032..64d6f59e 100644 --- a/libraries/python/openai-client/openai_client/tools.py +++ b/libraries/python/openai-client/openai_client/tools.py @@ -17,7 +17,7 @@ from . import logger from .completion import assistant_message_from_completion from .errors import CompletionError, validate_completion -from .logging import extra_data, serializable_completion_args +from .logging import add_serializable_data, make_completion_args_serializable async def execute_tool_call( @@ -31,16 +31,18 @@ async def execute_tool_call( if function_registry.has_function(function.name): logger.debug( "Function call.", - extra=extra_data({"name": function.name, "arguments": function.arguments}), + extra=add_serializable_data({"name": function.name, "arguments": function.arguments}), ) try: kwargs: dict[str, Any] = json.loads(function.arguments) value = await function_registry.execute_function_with_string_response(function.name, (), kwargs) except Exception as e: - logger.error("Error.", extra=extra_data({"error": e})) + logger.error("Error.", extra=add_serializable_data({"error": e})) value = f"Error: {e}" finally: - logger.debug("Function response.", extra=extra_data({"tool_call_id": tool_call.id, "content": value})) + logger.debug( + "Function response.", extra=add_serializable_data({"tool_call_id": tool_call.id, "content": value}) + ) return { "role": "tool", "content": value, @@ -68,7 +70,7 @@ def function_registry_to_tools(function_registry: FunctionRegistry) -> Iterable[ ] -def tool_choice(functions: list[str] | None) -> Iterable[ChatCompletionToolParam] | None: +def function_list_to_tools(functions: list[str] | None) -> Iterable[ChatCompletionToolParam] | None: if not functions: return None return [ @@ -100,19 +102,21 @@ async def complete_with_tool_calls( new_messages: list[ChatCompletionMessageParam] = [] # Completion call. - logger.debug("Completion call.", extra=extra_data(serializable_completion_args(completion_args))) - metadata["completion_args"] = serializable_completion_args(completion_args) + logger.debug("Completion call.", extra=add_serializable_data(make_completion_args_serializable(completion_args))) + metadata["completion_args"] = make_completion_args_serializable(completion_args) try: completion = await async_client.beta.chat.completions.parse( **completion_args, ) validate_completion(completion) - logger.debug("Completion response.", extra=extra_data({"completion": completion.model_dump()})) + logger.debug("Completion response.", extra=add_serializable_data({"completion": completion.model_dump()})) metadata["completion"] = completion.model_dump() except CompletionError as e: completion_error = CompletionError(e) metadata["completion_error"] = completion_error.message - logger.error(e.message, extra=extra_data({"completion_error": completion_error.body, "metadata": metadata})) + logger.error( + e.message, extra=add_serializable_data({"completion_error": completion_error.body, "metadata": metadata}) + ) raise completion_error from e # Extract response and add to messages. @@ -135,21 +139,21 @@ async def complete_with_tool_calls( final_args = {**completion_args, "messages": [*messages, *new_messages]} del final_args["tools"] del final_args["tool_choice"] - logger.debug("Tool completion call.", extra=extra_data(serializable_completion_args(final_args))) - metadata["tool_completion_args"] = serializable_completion_args(final_args) + logger.debug("Tool completion call.", extra=add_serializable_data(make_completion_args_serializable(final_args))) + metadata["tool_completion_args"] = make_completion_args_serializable(final_args) try: tool_completion = await async_client.beta.chat.completions.parse( **final_args, ) validate_completion(tool_completion) - logger.debug("Tool completion response.", extra=extra_data({"completion": completion.model_dump()})) + logger.debug("Tool completion response.", extra=add_serializable_data({"completion": completion.model_dump()})) metadata["completion"] = completion.model_dump() except Exception as e: tool_completion_error = CompletionError(e) metadata["tool_completion_error"] = tool_completion_error.message logger.error( tool_completion_error.message, - extra=extra_data({"completion_error": tool_completion_error.body, "metadata": metadata}), + extra=add_serializable_data({"completion_error": tool_completion_error.body, "metadata": metadata}), ) raise tool_completion_error from e