Skip to content

Commit

Permalink
renames helper functions in openai_client for clearer understanding f…
Browse files Browse the repository at this point in the history
…or reuse and consistency (microsoft#221)
  • Loading branch information
bkrabach authored Nov 6, 2024
1 parent 071eb73 commit c87f924
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 64 deletions.
14 changes: 7 additions & 7 deletions libraries/python/chat-driver/chat_driver/chat_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
ChatCompletionUserMessageParam,
)
from openai.types.chat.completion_create_params import ResponseFormat
from openai_client.completion import TEXT_RESPONSE_FORMAT, completion_message_string
from openai_client.completion import TEXT_RESPONSE_FORMAT, message_string_from_completion
from openai_client.errors import CompletionError
from openai_client.messages import MessageFormatter, format_message
from openai_client.tools import complete_with_tool_calls, function_registry_to_tools, tool_choice
from openai_client.messages import MessageFormatter, format_with_dict
from openai_client.tools import complete_with_tool_calls, function_list_to_tools, function_registry_to_tools
from pydantic import BaseModel

from .local_message_history_provider import (
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, config: ChatDriverConfig) -> None:
self.instructions: list[str] = (
config.instructions if isinstance(config.instructions, list) else [config.instructions]
)
self.instruction_formatter = config.instruction_formatter or format_message
self.instruction_formatter = config.instruction_formatter or format_with_dict

# Now set up the OpenAI client and model.
self.client = config.openai_client
Expand Down Expand Up @@ -178,7 +178,7 @@ async def respond(
"model": self.model,
"messages": [*self._formatted_instructions(instruction_parameters), *(await self.message_provider.get())],
"tools": function_registry_to_tools(self.function_registry),
"tool_choice": tool_choice(function_choice),
"tool_choice": function_list_to_tools(function_choice),
"response_format": response_format,
}
try:
Expand All @@ -198,7 +198,7 @@ async def respond(
# Return the response.

return MessageEvent(
message=completion_message_string(completion) or None,
message=message_string_from_completion(completion) or None,
metadata=metadata,
)

Expand All @@ -216,7 +216,7 @@ def format_instructions(
with the variables. This method returns a list of system messages formatted
with the variables.
"""
formatter = formatter or format_message
formatter = formatter or format_with_dict
instruction_messages: list[ChatCompletionSystemMessageParam] = []
for instruction in instructions:
if vars:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCallParam
from openai_client.messages import (
MessageFormatter,
assistant_message,
format_message,
system_message,
user_message,
create_assistant_message,
create_system_message,
create_user_message,
format_with_dict,
)


Expand All @@ -17,7 +17,7 @@ def __init__(
messages: list[ChatCompletionMessageParam] | None = None,
formatter: MessageFormatter | None = None,
) -> None:
self.formatter: MessageFormatter = formatter or format_message
self.formatter: MessageFormatter = formatter or format_with_dict
self.messages = messages or []

async def get(self) -> list[ChatCompletionMessageParam]:
Expand All @@ -40,10 +40,10 @@ def delete_all(self) -> None:
self.messages = []

def append_system_message(self, content: str, var: dict[str, Any] | None = None) -> None:
asyncio.run(self.append(system_message(content, var, self.formatter)))
asyncio.run(self.append(create_system_message(content, var, self.formatter)))

def append_user_message(self, content: str, var: dict[str, Any] | None = None) -> None:
asyncio.run(self.append(user_message(content, var, self.formatter)))
asyncio.run(self.append(create_user_message(content, var, self.formatter)))

def append_assistant_message(
self,
Expand All @@ -52,4 +52,4 @@ def append_assistant_message(
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None,
var: dict[str, Any] | None = None,
) -> None:
asyncio.run(self.append(assistant_message(content, refusal, tool_calls, var, self.formatter)))
asyncio.run(self.append(create_assistant_message(content, refusal, tool_calls, var, self.formatter)))
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
)
from openai_client.messages import (
MessageFormatter,
assistant_message,
format_message,
system_message,
user_message,
create_assistant_message,
create_system_message,
create_user_message,
format_with_dict,
)

from .message_history_provider import MessageHistoryProviderProtocol
Expand All @@ -36,7 +36,7 @@ def __init__(self, config: LocalMessageHistoryProviderConfig) -> None:
self.data_dir = DEFAULT_DATA_DIR / "chat_driver" / config.context.session_id
else:
self.data_dir = Path(config.data_dir)
self.formatter: MessageFormatter = config.formatter or format_message
self.formatter: MessageFormatter = config.formatter or format_with_dict

# Create the messages file if it doesn't exist.
if not self.data_dir.exists():
Expand Down Expand Up @@ -79,10 +79,10 @@ def delete_all(self) -> None:
self.messages_file.write_text("[]")

async def append_system_message(self, content: str, var: dict[str, Any] | None = None) -> None:
await self.append(system_message(content, var, self.formatter))
await self.append(create_system_message(content, var, self.formatter))

async def append_user_message(self, content: str, var: dict[str, Any] | None = None) -> None:
await self.append(user_message(content, var, self.formatter))
await self.append(create_user_message(content, var, self.formatter))

async def append_assistant_message(
self,
Expand All @@ -91,4 +91,4 @@ async def append_assistant_message(
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None,
var: dict[str, Any] | None = None,
) -> None:
await self.append(assistant_message(content, refusal, tool_calls, var, self.formatter))
await self.append(create_assistant_message(content, refusal, tool_calls, var, self.formatter))
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from chat_driver.chat_driver import ChatDriver
from openai_client.messages import format_message
from openai_client.messages import format_with_dict


def test_formatted_instructions() -> None:
Expand Down Expand Up @@ -34,7 +34,7 @@ def test_formatted_instructions() -> None:
"user_feedback": user_feedback,
"chat_history": chat_history,
},
formatter=format_message,
formatter=format_with_dict,
)

expected = [
Expand Down
3 changes: 2 additions & 1 deletion libraries/python/openai-client/.vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
},
"cSpell.words": [
"openai",
"Pydantic"
"Pydantic",
"tiktoken"
]
}
8 changes: 4 additions & 4 deletions libraries/python/openai-client/openai_client/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ def assistant_message_from_completion(completion: ParsedChatCompletion[None]) ->
return assistant_message


def completion_message(completion: ParsedChatCompletion) -> ParsedChatCompletionMessage | None:
def message_from_completion(completion: ParsedChatCompletion) -> ParsedChatCompletionMessage | None:
return completion.choices[0].message if completion and completion.choices else None


def completion_message_string(completion: ParsedChatCompletion | None) -> str:
def message_string_from_completion(completion: ParsedChatCompletion | None) -> str:
if not completion or not completion.choices or not completion.choices[0].message:
return ""
return completion.choices[0].message.content or ""


def completion_message_dict(completion: ParsedChatCompletion) -> dict[str, Any] | None:
message = completion_message(completion)
def message_dict_from_completion(completion: ParsedChatCompletion) -> dict[str, Any] | None:
message = message_from_completion(completion)
if message:
if message.parsed:
if isinstance(message.parsed, BaseModel):
Expand Down
10 changes: 5 additions & 5 deletions libraries/python/openai-client/openai_client/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@
)


class InValidCompletionError(Exception):
class CompletionInvalidError(Exception):
def __init__(self, message: str, body: dict[str, Any] | None = None) -> None:
self.message = message
self.body = body
super().__init__(self.message)


class CompletionIsNoneError(InValidCompletionError):
class CompletionIsNoneError(CompletionInvalidError):
def __init__(self) -> None:
super().__init__("The completion response is None.")


class CompletionRefusedError(InValidCompletionError):
class CompletionRefusedError(CompletionInvalidError):
def __init__(self, refusal: str) -> None:
super().__init__(f"The model refused to complete the response: {refusal}", {"refusal": refusal})


class CompletionWithoutStopError(InValidCompletionError):
class CompletionWithoutStopError(CompletionInvalidError):
def __init__(self, finish_reason: str) -> None:
super().__init__(f"The model did not complete the response: {finish_reason}", {"finish_reason": finish_reason})

Expand All @@ -47,7 +47,7 @@ def __init__(self, error: Exception) -> None:
elif isinstance(error, APIStatusError):
message = f"Another non-200-range status code was received. {error.status_code}: {error.message}"
body = error.body
elif isinstance(error, InValidCompletionError):
elif isinstance(error, CompletionInvalidError):
message = error.message
body = error.body
else:
Expand Down
6 changes: 3 additions & 3 deletions libraries/python/openai-client/openai_client/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel


def serializable_completion_args(completion_args: dict[str, Any]) -> dict[str, Any]:
def make_completion_args_serializable(completion_args: dict[str, Any]) -> dict[str, Any]:
"""
We put the completion args into logs and messages, so it's important that
they are serializable. This function returns a copy of the completion args
Expand All @@ -30,9 +30,9 @@ def serializable_completion_args(completion_args: dict[str, Any]) -> dict[str, A
return sanitized


def extra_data(data: Any) -> dict[str, Any]:
def add_serializable_data(data: Any) -> dict[str, Any]:
"""
Helper function to add extra data to log messages.
Helper function to use when adding extra data to log messages.
"""
extra = {}

Expand Down
26 changes: 13 additions & 13 deletions libraries/python/openai-client/openai_client/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def truncate_messages_for_logging(
results.append(message)

case list():
compressed = process_list(content, maximum_content_length, filler_text)
compressed = apply_truncation_to_list(content, maximum_content_length, filler_text)
message["content"] = compressed # type: ignore
results.append(message)

Expand All @@ -57,33 +57,33 @@ def truncate_string(string: str, maximum_length: int, filler_text: str) -> str:
return string[:head_tail_length] + filler_text + string[-head_tail_length:]


def process_list(list_: list, maximum_length: int, filler_text: str) -> list:
def apply_truncation_to_list(list_: list, maximum_length: int, filler_text: str) -> list:
for part in list_:
for key, value in part.items():
match value:
case str():
part[key] = truncate_string(value, maximum_length, filler_text)

case dict():
part[key] = process_dict(value, maximum_length, filler_text)
part[key] = apply_truncation_to_dict(value, maximum_length, filler_text)
return list_


def process_dict(dict_: dict, maximum_length: int, filler_text: str) -> dict:
def apply_truncation_to_dict(dict_: dict, maximum_length: int, filler_text: str) -> dict:
for key, value in dict_.items():
match value:
case str():
dict_[key] = truncate_string(value, maximum_length, filler_text)

case dict():
dict_[key] = process_dict(value, maximum_length, filler_text)
dict_[key] = apply_truncation_to_dict(value, maximum_length, filler_text)
return dict_


MessageFormatter = Callable[[str, dict[str, Any]], str]


def format_message(message: str, vars: dict[str, Any]) -> str:
def format_with_dict(message: str, vars: dict[str, Any]) -> str:
"""
Format a message with the given variables using the Python format method.
"""
Expand All @@ -96,7 +96,7 @@ def format_message(message: str, vars: dict[str, Any]) -> str:
return message


def liquid_format(message: str, vars: dict[str, Any]) -> str:
def format_with_liquid(message: str, vars: dict[str, Any]) -> str:
"""
Format a message with the given variables using the Liquid template engine.
"""
Expand All @@ -108,28 +108,28 @@ def liquid_format(message: str, vars: dict[str, Any]) -> str:
return out


def system_message(
content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_message
def create_system_message(
content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_with_dict
) -> ChatCompletionSystemMessageParam:
if var:
content = formatter(content, var)
return {"role": "system", "content": content}


def user_message(
content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_message
def create_user_message(
content: str, var: dict[str, Any] | None = None, formatter: MessageFormatter = format_with_dict
) -> ChatCompletionUserMessageParam:
if var:
content = formatter(content, var)
return {"role": "user", "content": content}


def assistant_message(
def create_assistant_message(
content: str,
refusal: Optional[str] = None,
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None = None,
var: dict[str, Any] | None = None,
formatter: MessageFormatter = format_message,
formatter: MessageFormatter = format_with_dict,
) -> ChatCompletionAssistantMessageParam:
if var:
content = formatter(content, var)
Expand Down
Loading

0 comments on commit c87f924

Please sign in to comment.