From 84ead8420186b3ff01da231ff75a9d759c96a745 Mon Sep 17 00:00:00 2001 From: David Koleczek <45405824+DavidKoleczek@users.noreply.github.com> Date: Sat, 8 Feb 2025 17:16:47 +0000 Subject: [PATCH] openai streaming --- .vscode/launch.json | 2 +- poetry.lock | 28 +- pyproject.toml | 4 +- .../llm/chat_completion/__init__.py | 4 +- .../llm/chat_completion/interface.py | 32 +- .../chat_completion/providers/openai_api.py | 203 +++++++-- src/not_again_ai/llm/chat_completion/types.py | 44 ++ .../test_chat_completion_stream.py | 393 ++++++++++++++++++ 8 files changed, 661 insertions(+), 49 deletions(-) create mode 100644 tests/llm/chat_completion/test_chat_completion_stream.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 92390e4..9102e55 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,7 +5,7 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Current File", + "name": "Python Debugger: Current File", "type": "debugpy", "request": "launch", "program": "${file}", diff --git a/poetry.lock b/poetry.lock index 0558e6c..fae3d6c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2800,6 +2800,26 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.25.3" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["test"] +markers = "python_version == \"3.11\" or python_version >= \"3.12\"" +files = [ + {file = "pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3"}, + {file = "pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-base-url" version = "2.1.0" @@ -2893,15 +2913,15 @@ six = ">=1.5" [[package]] name = "python-liquid" -version = "1.12.2" +version = "1.13.0" description = "A Python engine for the Liquid template language." 
optional = true python-versions = ">=3.7" groups = ["main"] markers = "python_version == \"3.11\" and extra == \"llm\" or python_version >= \"3.12\" and extra == \"llm\"" files = [ - {file = "python_liquid-1.12.2-py3-none-any.whl", hash = "sha256:4a2611ecb7d77476f2a028438cf02213873ec06b1223f661ba16eceec130e192"}, - {file = "python_liquid-1.12.2.tar.gz", hash = "sha256:46b5deb2337c1afe91760f2da46b05f17b55482023da441a4a78d096c4271d19"}, + {file = "python_liquid-1.13.0-py3-none-any.whl", hash = "sha256:843ed7b8af00c1480d1bf402553ed07bdddce130ebbf4b1fefc84bb2e076f5d4"}, + {file = "python_liquid-1.13.0.tar.gz", hash = "sha256:c158fbaad6dd41c49de7cff34e3611bff7211326fb3049322d27673e1d66e166"}, ] [package.dependencies] @@ -3864,4 +3884,4 @@ viz = ["numpy", "pandas", "seaborn"] [metadata] lock-version = "2.1" python-versions = ">=3.11, <3.13" -content-hash = "6861b50ce845a53d9318c10e475bf09e016878df22848300b7a39e233f5a73a0" +content-hash = "44a539d8fdd0304109197a76bb838eb143610eb0138881af60d1c3afb5e850be" diff --git a/pyproject.toml b/pyproject.toml index de2e564..f4ff05d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "not-again-ai" -version = "0.16.1" +version = "0.17.0" description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place." authors = [ { name = "DaveCoDev", email = "dave.co.dev@gmail.com" } @@ -70,6 +70,7 @@ nox-poetry = "*" [tool.poetry.group.test.dependencies] pytest = "*" +pytest-asyncio = "*" pytest-cov = "*" pytest-randomly = "*" @@ -153,6 +154,7 @@ filterwarnings = [ # "ignore::DeprecationWarning:typer", "ignore::pytest.PytestUnraisableExceptionWarning" ] +asyncio_mode = "auto" [tool.coverage.run] branch = true \ No newline at end of file diff --git a/src/not_again_ai/llm/chat_completion/__init__.py b/src/not_again_ai/llm/chat_completion/__init__.py index e268d4e..8d8d339 100644 --- a/src/not_again_ai/llm/chat_completion/__init__.py +++ b/src/not_again_ai/llm/chat_completion/__init__.py @@ -1,4 +1,4 @@ -from not_again_ai.llm.chat_completion.interface import chat_completion +from not_again_ai.llm.chat_completion.interface import chat_completion, chat_completion_stream from not_again_ai.llm.chat_completion.types import ChatCompletionRequest -__all__ = ["ChatCompletionRequest", "chat_completion"] +__all__ = ["ChatCompletionRequest", "chat_completion", "chat_completion_stream"] diff --git a/src/not_again_ai/llm/chat_completion/interface.py b/src/not_again_ai/llm/chat_completion/interface.py index b135d48..522b03c 100644 --- a/src/not_again_ai/llm/chat_completion/interface.py +++ b/src/not_again_ai/llm/chat_completion/interface.py @@ -1,9 +1,9 @@ -from collections.abc import Callable +from collections.abc import AsyncGenerator, Callable from typing import Any from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion -from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion -from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, ChatCompletionResponse +from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream +from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse def chat_completion( @@ -30,3 +30,29 @@ def chat_completion( return ollama_chat_completion(request, client) else: raise ValueError(f"Provider {provider} not supported") + + +async def 
chat_completion_stream(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    """Stream a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    Other providers (such as `ollama`) are not yet supported here and will raise a ValueError.
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Yields:
+        ChatCompletionChunk objects as they are received from the provider
+    """
+    request.stream = True
+    if provider == "openai" or provider == "azure_openai":
+        async for chunk in openai_chat_completion_stream(request, client):
+            yield chunk
+    else:
+        raise ValueError(f"Provider {provider} not supported")
diff --git a/src/not_again_ai/llm/chat_completion/providers/openai_api.py b/src/not_again_ai/llm/chat_completion/providers/openai_api.py
index b3f63ca..d252c11 100644
--- a/src/not_again_ai/llm/chat_completion/providers/openai_api.py
+++ b/src/not_again_ai/llm/chat_completion/providers/openai_api.py
@@ -1,17 +1,23 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable, Coroutine
 import json
 import time
 from typing import Any, Literal
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-from openai import AzureOpenAI, OpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 
 from not_again_ai.llm.chat_completion.types import (
     AssistantMessage,
     ChatCompletionChoice,
+    ChatCompletionChoiceStream,
+    ChatCompletionChunk,
+    ChatCompletionDelta,
     ChatCompletionRequest,
     ChatCompletionResponse,
     Function,
+    PartialFunction,
+    PartialToolCall,
+    Role,
     ToolCall,
 )
 
@@ -36,12 +42,7 @@ def validate(request: ChatCompletionRequest) -> None:
         raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")
 
 
-def openai_chat_completion(
-    request: ChatCompletionRequest,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    validate(request)
-
+def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
     # Format the response format parameters to be compatible with OpenAI API
     if request.json_mode:
         response_format: dict[str, Any] = {"type": "json_object"}
@@ -61,7 +62,6 @@ def openai_chat_completion(
         elif value is None and key in kwargs:
             del kwargs[key]
 
-    # Iterate over each message and
     for message in kwargs["messages"]:
         role = message.get("role", None)
         # For each ToolMessage, change the "name" field to be named "tool_call_id" instead
@@ -84,6 +84,49 @@ def openai_chat_completion(
     if request.tool_choice is not None and request.tool_choice not in ["none", "auto", "required"]:
         kwargs["tool_choice"] = {"type": "function", "function": {"name": request.tool_choice}}
 
+    return kwargs
+
+
+def process_logprobs(logprobs_content: list[dict[str, Any]]) -> list[dict[str, Any] | list[dict[str, Any]]]:
+    """Process logprobs content from OpenAI API response.
+ + Args: + logprobs_content: List of logprob entries from the API response + + Returns: + Processed logprobs list containing either single token info or lists of top token infos + """ + logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = [] + for logprob in logprobs_content: + if logprob.get("top_logprobs", None): + curr_logprob_infos: list[dict[str, Any]] = [] + for top_logprob in logprob.get("top_logprobs", []): + curr_logprob_infos.append( + { + "token": top_logprob.get("token", ""), + "logprob": top_logprob.get("logprob", 0), + "bytes": top_logprob.get("bytes", 0), + } + ) + logprobs_list.append(curr_logprob_infos) + else: + logprobs_list.append( + { + "token": logprob.get("token", ""), + "logprob": logprob.get("logprob", 0), + "bytes": logprob.get("bytes", 0), + } + ) + return logprobs_list + + +def openai_chat_completion( + request: ChatCompletionRequest, + client: Callable[..., Any], +) -> ChatCompletionResponse: + validate(request) + kwargs = format_kwargs(request) + start_time = time.time() response = client(**kwargs) end_time = time.time() @@ -133,28 +176,7 @@ def openai_chat_completion( # Handle logprobs logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None: - logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = [] - for logprob in choice["logprobs"]["content"]: - if logprob.get("top_logprobs", None): - curr_logprob_infos: list[dict[str, Any]] = [] - for top_logprob in logprob.get("top_logprobs", []): - curr_logprob_infos.append( - { - "token": top_logprob.get("token", ""), - "logprob": top_logprob.get("logprob", 0), - "bytes": top_logprob.get("bytes", 0), - } - ) - logprobs_list.append(curr_logprob_infos) - else: - logprobs_list.append( - { - "token": logprob.get("token", ""), - "logprob": logprob.get("logprob", 0), - "bytes": logprob.get("bytes", 0), - } - ) - logprobs = logprobs_list + logprobs = process_logprobs(choice["logprobs"]["content"]) # Handle extras that OpenAI or Azure OpenAI return if choice.get("content_filter_results", None): @@ -195,6 +217,93 @@ def openai_chat_completion( ) +async def openai_chat_completion_stream( + request: ChatCompletionRequest, + client: Callable[..., Any], +) -> AsyncGenerator[ChatCompletionChunk, None]: + validate(request) + kwargs = format_kwargs(request) + + stream = await client(**kwargs) + + async for chunk in stream: + errors = "" + # This kind of a hack. 
To make this processing generic for clients that do not return the correct
+        # data structure, we convert the chunk to a dict
+        if not isinstance(chunk, dict):
+            chunk = chunk.to_dict()
+
+        choices: list[ChatCompletionChoiceStream] = []
+        for choice_index, choice in enumerate(chunk["choices"]):
+            content = choice.get("delta", {}).get("content", "")
+            if not content:
+                content = ""
+
+            role = Role.ASSISTANT
+            if choice.get("delta", {}).get("role", None):
+                role = Role(choice["delta"]["role"])
+
+            # Handle tool calls
+            tool_calls: list[PartialToolCall] | None = None
+            if choice["delta"].get("tool_calls", None):
+                parsed_tool_calls: list[PartialToolCall] = []
+                for tool_call in choice["delta"]["tool_calls"]:
+                    tool_name = tool_call.get("function", {}).get("name", None)
+                    if not tool_name:
+                        tool_name = ""
+                    # Check if the tool name is valid (one of the tool names in the request)
+                    if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                        errors += (
+                            f"Choice {choice_index}: Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+                        )
+
+                    tool_args = tool_call.get("function", {}).get("arguments", "")
+                    if not tool_args:
+                        tool_args = ""
+
+                    tool_id = tool_call.get("id", None)
+                    parsed_tool_calls.append(
+                        PartialToolCall(
+                            id=tool_id,
+                            function=PartialFunction(
+                                name=tool_name,
+                                arguments=tool_args,
+                            ),
+                        )
+                    )
+                tool_calls = parsed_tool_calls
+
+            refusal = None
+            if choice["delta"].get("refusal", None):
+                refusal = choice["delta"]["refusal"]
+
+            delta = ChatCompletionDelta(
+                content=content,
+                role=role,
+                tool_calls=tool_calls,
+                refusal=refusal,
+            )
+
+            index = choice.get("index", 0)
+            finish_reason = choice.get("finish_reason", None)
+
+            # Handle logprobs
+            logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
+            if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
+                logprobs = process_logprobs(choice["logprobs"]["content"])
+
+            choice_obj = ChatCompletionChoiceStream(
+                delta=delta,
+                finish_reason=finish_reason,
+                logprobs=logprobs,
+                index=index,
+            )
+            choices.append(choice_obj)
+
+        chunk_obj = ChatCompletionChunk(choices=choices, errors=errors)
+        yield chunk_obj
+
+
 def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
     """Creates a callable that instantiates and uses an OpenAI client.
 
@@ -215,6 +324,19 @@ def client_callable(**kwargs: Any) -> Any:
     return client_callable
 
 
+def create_client_callable_stream(
+    client_class: type[AsyncOpenAI | AsyncAzureOpenAI], **client_args: Any
+) -> Callable[..., Any]:
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Coroutine[Any, Any, Any]:
+        client = client_class(**filtered_args)
+        stream = client.chat.completions.create(**kwargs)
+        return stream
+
+    return client_callable
+
+
 class InvalidOAIAPITypeError(Exception):
     """Raised when an invalid OAIAPIType string is provided."""
 
@@ -227,6 +349,7 @@ def openai_client(
     azure_endpoint: str | None = None,
     timeout: float | None = None,
     max_retries: int | None = None,
+    async_client: bool = False,
 ) -> Callable[..., Any]:
     """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
 
@@ -247,11 +370,11 @@
         max_retries (int, optional): Certain errors are automatically retried 2 times by default, with a short exponential backoff.
Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default. + async_client (bool, optional): Whether to return an async client. Defaults to False. Returns: Callable[..., Any]: A callable that creates a client and returns completion results - Raises: InvalidOAIAPITypeError: If an invalid API type string is provided. NotImplementedError: If the specified API type is recognized but not yet supported (e.g., 'azure_openai'). @@ -260,17 +383,21 @@ def openai_client( raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.") if api_type == "openai": - return create_client_callable( - OpenAI, + client_class = AsyncOpenAI if async_client else OpenAI + callable_creator = create_client_callable_stream if async_client else create_client_callable + return callable_creator( + client_class, # type: ignore api_key=api_key, organization=organization, timeout=timeout, max_retries=max_retries, ) elif api_type == "azure_openai": + azure_client_class = AsyncAzureOpenAI if async_client else AzureOpenAI + callable_creator = create_client_callable_stream if async_client else create_client_callable if api_key: - return create_client_callable( - AzureOpenAI, + return callable_creator( + azure_client_class, # type: ignore api_version=aoai_api_version, azure_endpoint=azure_endpoint, api_key=api_key, @@ -282,8 +409,8 @@ def openai_client( ad_token_provider = get_bearer_token_provider( azure_credential, "https://cognitiveservices.azure.com/.default" ) - return create_client_callable( - AzureOpenAI, + return callable_creator( + azure_client_class, # type: ignore api_version=aoai_api_version, azure_endpoint=azure_endpoint, azure_ad_token_provider=ad_token_provider, diff --git a/src/not_again_ai/llm/chat_completion/types.py b/src/not_again_ai/llm/chat_completion/types.py index 8fa10fd..18ecfbd 100644 --- a/src/not_again_ai/llm/chat_completion/types.py +++ b/src/not_again_ai/llm/chat_completion/types.py @@ -52,12 +52,23 @@ class Function(BaseModel): arguments: dict[str, Any] +class PartialFunction(BaseModel): + name: str + arguments: str + + class ToolCall(BaseModel): id: str function: Function type: Literal["function"] = "function" +class PartialToolCall(BaseModel): + id: str | None + function: PartialFunction + type: Literal["function"] = "function" + + class DeveloperMessage(BaseMessage[str]): role: Literal[Role.DEVELOPER] = Role.DEVELOPER @@ -87,6 +98,7 @@ class ToolMessage(BaseMessage[str]): class ChatCompletionRequest(BaseModel): messages: list[MessageT] model: str + stream: bool = Field(default=False) max_completion_tokens: int | None = Field(default=None) context_window: int | None = Field(default=None) @@ -148,3 +160,35 @@ class ChatCompletionResponse(BaseModel): system_fingerprint: str | None = Field(default=None) extras: Any | None = Field(default=None) + + +class ChatCompletionDelta(BaseModel): + content: str + role: Role = Field(default=Role.ASSISTANT) + + tool_calls: list[PartialToolCall] | None = Field(default=None) + + refusal: str | None = Field(default=None) + + +class ChatCompletionChoiceStream(BaseModel): + delta: ChatCompletionDelta + index: int + finish_reason: Literal["stop", "length", "tool_calls", "content_filter"] | None + + logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None) + + extras: Any | None = Field(default=None) + + +class ChatCompletionChunk(BaseModel): + choices: list[ChatCompletionChoiceStream] + + 
errors: str = Field(default="") + + completion_tokens: int | None = Field(default=None) + prompt_tokens: int | None = Field(default=None) + response_duration: float | None = Field(default=None) + + system_fingerprint: str | None = Field(default=None) + extras: Any | None = Field(default=None) diff --git a/tests/llm/chat_completion/test_chat_completion_stream.py b/tests/llm/chat_completion/test_chat_completion_stream.py new file mode 100644 index 0000000..2459971 --- /dev/null +++ b/tests/llm/chat_completion/test_chat_completion_stream.py @@ -0,0 +1,393 @@ +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import pytest + +from not_again_ai.llm.chat_completion import chat_completion_stream +from not_again_ai.llm.chat_completion.providers.openai_api import openai_client +from not_again_ai.llm.chat_completion.types import ( + ChatCompletionRequest, + ImageContent, + ImageDetail, + ImageUrl, + MessageT, + SystemMessage, + TextContent, + UserMessage, +) +from not_again_ai.llm.prompting.compile_prompt import encode_image + +image_dir = Path(__file__).parent.parent / "sample_images" +cat_image = image_dir / "cat.jpg" +dog_image = image_dir / "dog.jpg" +numbers_image = image_dir / "numbers.png" +sk_infographic = image_dir / "SKInfographic.png" +sk_diagram = image_dir / "SKDiagram.png" + + +# region Azure OpenAI + + +@pytest.fixture( + params=[ + {"async_client": True}, + {"api_type": "azure_openai", "aoai_api_version": "2025-01-01-preview", "async_client": True}, + ] +) +def openai_aoai_client_fixture(request: pytest.FixtureRequest) -> Callable[..., Any]: + return openai_client(**request.param) + + +async def test_chat_completion_stream_simple(openai_aoai_client_fixture: Callable[..., Any]) -> None: + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[ + SystemMessage(content="Hello, world!"), + UserMessage(content="What is the capital of France?"), + ], + max_completion_tokens=100, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_json_mode(openai_aoai_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage( + content="""You are getting names of users and formatting them into json. 
+Example: +User: Jane Doe +Output: {"name": "Jane Doe"}""" + ), + UserMessage(content="John Doe"), + ] + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=messages, + max_completion_tokens=200, + temperature=0, + json_mode=True, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_n(openai_aoai_client_fixture: Callable[..., Any]) -> None: + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Hello!"), + ], + max_completion_tokens=100, + n=2, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_toplogprobs(openai_aoai_client_fixture: Callable[..., Any]) -> None: + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Hello!"), + ], + max_completion_tokens=100, + logprobs=True, + top_logprobs=3, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_tool_simple(openai_aoai_client_fixture: Callable[..., Any]) -> None: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + } + ] + + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[ + UserMessage( + content="What's the current weather like in Boston, MA today? Call the get_current_weather function." + ) + ], + tools=tools, + max_completion_tokens=300, + temperature=0, + ) + + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_tool_required_name(openai_aoai_client_fixture: Callable[..., Any]) -> None: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. 
Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + ] + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[ + UserMessage(content="What's the current weather like in Boston, MA today?"), + ], + tools=tools, + tool_choice="get_current_weather", + max_completion_tokens=300, + temperature=0, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_multiple_tools(openai_aoai_client_fixture: Callable[..., Any]) -> None: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + ] + messages: list[MessageT] = [ + SystemMessage(content="Call the get_current_weather function once for each city that the user mentions."), + UserMessage(content="What's the current weather like in Boston, MA and New York, NY today?"), + ] + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=messages, + tools=tools, + max_completion_tokens=400, + temperature=0, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_message_with_tools(openai_aoai_client_fixture: Callable[..., Any]) -> None: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + ] + messages: list[MessageT] = [ + SystemMessage( + content="""You will be given a function called get_current_weather. +Before calling the function, first reason about which city the user is asking about. YOU MUST think step by step before calling the function. +For example, if the user asks 'What's the current weather like in Boston, MA today?', You should first say 'The user is asking about Boston, MA so I will call the function with 'Boston, MA' """ + ), + UserMessage( + content="What's the current weather like in Boston, MA today? First think step by step as to which city and state to call, only then call the get_current_weather function. 
" + ), + ] + + request = ChatCompletionRequest( + model="gpt-4o-2024-11-20", + messages=messages, + tools=tools, + max_completion_tokens=600, + temperature=0.7, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_structured_output(openai_aoai_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="9.11 and 9.9 -- which is bigger?"), + ] + json_schema = { + "name": "reasoning_schema", + "strict": True, + "schema": { + "type": "object", + "properties": { + "reasoning_steps": { + "type": "array", + "items": {"type": "string"}, + "description": "The reasoning steps leading to the final conclusion.", + }, + "answer": { + "type": "string", + "description": "The final answer, taking into account the reasoning steps.", + }, + }, + "required": ["reasoning_steps", "answer"], + "additionalProperties": False, + }, + "description": "A schema for structured output that includes reasoning steps and the final answer.", + } + + request = ChatCompletionRequest( + messages=messages, + model="gpt-4o-2024-11-20", + max_completion_tokens=400, + temperature=0.2, + json_mode=False, + structured_outputs=json_schema, + ) + + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_vision(openai_aoai_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage( + content=[ + TextContent(text="Describe the animal in the image in one word."), + ImageContent( + image_url=ImageUrl(url=f"data:image/jpeg;base64,{encode_image(cat_image)}", detail=ImageDetail.LOW) + ), + ] + ), + ] + + request = ChatCompletionRequest( + messages=messages, + model="gpt-4o-2024-11-20", + max_completion_tokens=200, + temperature=0.5, + ) + + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_vision_tool_call(openai_aoai_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage( + content="""You are detecting if there is text (numbers or letters) in images. +If you see any text, call the ocr tool. 
It takes no parameters.""" + ), + UserMessage( + content=[ + ImageContent( + image_url=ImageUrl( + url=f"data:image/png;base64,{encode_image(numbers_image)}", detail=ImageDetail.LOW + ) + ), + ] + ), + ] + tools = [ + { + "type": "function", + "function": { + "name": "ocr", + "description": "Perform Optical Character Recognition (OCR) on an image", + "parameters": {}, + }, + }, + ] + + request = ChatCompletionRequest( + messages=messages, + model="gpt-4o-2024-11-20", + tools=tools, + max_completion_tokens=200, + ) + + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_invalid_params(openai_aoai_client_fixture: Callable[..., Any]) -> None: + request = ChatCompletionRequest( + model="gpt-4o-mini-2024-07-18", + messages=[UserMessage(content="What is the capital of France?")], + max_completion_tokens=100, + context_window=1000, + mirostat=1, + ) + async for chunk in chat_completion_stream(request, "openai", openai_aoai_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +# endregion + + +# region Ollama + + +# endregion
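
Usage sketch: a minimal example of how a caller might consume the new streaming interface, assuming the async client path added in this patch (openai_client(async_client=True), as in the test fixtures above); the model name and prompt are illustrative only.

import asyncio

from not_again_ai.llm.chat_completion import chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, UserMessage


async def main() -> None:
    # async_client=True selects the AsyncOpenAI-based callable, which chat_completion_stream requires.
    client = openai_client(async_client=True)
    request = ChatCompletionRequest(
        model="gpt-4o-mini-2024-07-18",  # illustrative model name, reused from the tests above
        messages=[UserMessage(content="What is the capital of France?")],
        max_completion_tokens=100,
    )
    content = ""
    async for chunk in chat_completion_stream(request, "openai", client):
        for choice in chunk.choices:
            # Each delta carries a (possibly empty) content fragment; concatenate to rebuild the reply.
            content += choice.delta.content
    print(content)


asyncio.run(main())

Tool-call deltas stream the same way: choice.delta.tool_calls holds PartialToolCall objects whose function.arguments are raw string fragments rather than parsed dicts, so a caller is expected to concatenate the fragments across chunks before parsing them.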