diff --git a/notebooks/llm/03_llm_streaming.ipynb b/notebooks/llm/03_llm_streaming.ipynb new file mode 100644 index 0000000..828824c --- /dev/null +++ b/notebooks/llm/03_llm_streaming.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Streaming APIs\n", + "\n", + "This notebook shows examples of how to use the streaming APIs for both OpenAI and Ollama.\n", + "\n", + "When instantiating the client, set `async_client=True` to use the async client, which supports streaming.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from not_again_ai.llm.chat_completion import chat_completion_stream\n", + "from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client\n", + "from not_again_ai.llm.chat_completion.providers.openai_api import openai_client\n", + "from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, SystemMessage, UserMessage\n", + "\n", + "openai_client = openai_client(async_client=True)\n", + "ollama_client = ollama_client(async_client=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0}], 'errors': '', 'response_duration': 1.5805}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': 'call_W5ScLS0Atqtt4awudPS8pA6G', 'function': {'name': 'get_current_weather', 'arguments': ''}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5835}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '{\"lo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5846}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'catio'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5858}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n\": \"B'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5868}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'osto'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.589}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n, MA'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5899}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\", \"fo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5909}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'rmat'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5918}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\": \"f'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5934}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ahrenh'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 
'response_duration': 1.5942}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'eit\"'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5955}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '}'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5964}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': 'call_sETMkonfyDQP5ndr8kGpHVcY', 'function': {'name': 'get_current_weather', 'arguments': ''}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5974}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '{\"lo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5983}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'catio'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5998}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n\": \"N'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6006}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ew Y'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6383}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ork, '}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6393}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'NY\", \"'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6403}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'form'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6413}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'at\": '}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6422}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\"fahre'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6431}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'nhei'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6449}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 't\"}'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6458}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0, 'finish_reason': 'tool_calls'}], 'errors': '', 'response_duration': 1.647}\n", + "{'choices': [], 'errors': '', 'completion_tokens': 62, 'prompt_tokens': 112, 'response_duration': 1.6504, 'system_fingerprint': 'fp_72ed7ab54c'}\n" + ] + } + ], + "source": [ + "# OpenAI Streaming Example with tools\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " 
\"description\": \"Get the current weather\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " },\n", + " \"format\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"description\": \"The temperature unit to use. Infer this from the users location.\",\n", + " },\n", + " },\n", + " \"required\": [\"location\", \"format\"],\n", + " },\n", + " },\n", + " },\n", + "]\n", + "messages = [\n", + " SystemMessage(content=\"Call the get_current_weather function once for each city that the user mentions.\"),\n", + " UserMessage(content=\"What's the current weather like in Boston, MA and New York, NY today?\"),\n", + "]\n", + "request = ChatCompletionRequest(\n", + " model=\"gpt-4o-mini-2024-07-18\",\n", + " messages=messages,\n", + " tools=tools,\n", + " max_completion_tokens=400,\n", + " temperature=0,\n", + ")\n", + "\n", + "async for chunk in chat_completion_stream(request, \"openai\", openai_client):\n", + " print(chunk.model_dump(mode=\"json\", exclude_none=True), flush=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': '', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'Boston, MA'}}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 0.4406}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': '', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'New York, NY'}}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 0.7869}\n", + "{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0, 'finish_reason': 'stop'}], 'errors': '', 'completion_tokens': 47, 'prompt_tokens': 204, 'response_duration': 0.8152}\n" + ] + } + ], + "source": [ + "# Ollama Streaming Example with tools\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " },\n", + " \"format\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"description\": \"The temperature unit to use. 
Infer this from the users location.\",\n", + " },\n", + " },\n", + " \"required\": [\"location\", \"format\"],\n", + " },\n", + " },\n", + " },\n", + "]\n", + "messages = [\n", + " SystemMessage(content=\"Call the get_current_weather function once for each city that the user mentions.\"),\n", + " UserMessage(content=\"What's the current weather like in Boston, MA and New York, NY today?\"),\n", + "]\n", + "request = ChatCompletionRequest(\n", + " model=\"qwen2.5:14b\",\n", + " messages=messages,\n", + " tools=tools,\n", + " max_completion_tokens=400,\n", + " temperature=0,\n", + ")\n", + "async for chunk in chat_completion_stream(request, \"ollama\", ollama_client):\n", + " print(chunk.model_dump(mode=\"json\", exclude_none=True))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/noxfile.py b/noxfile.py index 420cf82..5492c5c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -17,7 +17,7 @@ @session(python=["3.11", "3.12"]) def test(s: Session) -> None: - s.install(".[data,llm,statistics,viz]", "pytest", "pytest-cov", "pytest-randomly") + s.install(".[data,llm,statistics,viz]", "pytest", "pytest-asyncio", "pytest-cov", "pytest-randomly") # Skip tests in directories specified by the SKIP_TESTS_NAII environment variable. skip_tests = os.getenv("SKIP_TESTS_NAAI", "") diff --git a/pyproject.toml b/pyproject.toml index f4ff05d..85e7a10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ filterwarnings = [ "ignore::pytest.PytestUnraisableExceptionWarning" ] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [tool.coverage.run] branch = true \ No newline at end of file diff --git a/src/not_again_ai/llm/chat_completion/interface.py b/src/not_again_ai/llm/chat_completion/interface.py index 522b03c..3342b9d 100644 --- a/src/not_again_ai/llm/chat_completion/interface.py +++ b/src/not_again_ai/llm/chat_completion/interface.py @@ -1,7 +1,7 @@ from collections.abc import AsyncGenerator, Callable from typing import Any -from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion +from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse @@ -54,5 +54,8 @@ async def chat_completion_stream( if provider == "openai" or provider == "azure_openai": async for chunk in openai_chat_completion_stream(request, client): yield chunk + elif provider == "ollama": + async for chunk in ollama_chat_completion_stream(request, client): + yield chunk else: raise ValueError(f"Provider {provider} not supported") diff --git a/src/not_again_ai/llm/chat_completion/providers/ollama_api.py b/src/not_again_ai/llm/chat_completion/providers/ollama_api.py index 065abe6..ca05ec4 100644 --- a/src/not_again_ai/llm/chat_completion/providers/ollama_api.py +++ b/src/not_again_ai/llm/chat_completion/providers/ollama_api.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import 
AsyncGenerator, Callable import json import os import re @@ -6,14 +6,20 @@ from typing import Any, Literal, cast from loguru import logger -from ollama import ChatResponse, Client, ResponseError +from ollama import AsyncClient, ChatResponse, Client, ResponseError from not_again_ai.llm.chat_completion.types import ( AssistantMessage, ChatCompletionChoice, + ChatCompletionChoiceStream, + ChatCompletionChunk, + ChatCompletionDelta, ChatCompletionRequest, ChatCompletionResponse, Function, + PartialFunction, + PartialToolCall, + Role, ToolCall, ) @@ -51,14 +57,8 @@ def validate(request: ChatCompletionRequest) -> None: raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.") -def ollama_chat_completion( - request: ChatCompletionRequest, - client: Callable[..., Any], -) -> ChatCompletionResponse: - validate(request) - +def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]: kwargs = request.model_dump(mode="json", exclude_none=True) - # For each key in OLLAMA_PARAMETER_MAP # If it is not None, set the key in kwargs to the value of the corresponding value in OLLAMA_PARAMETER_MAP # If it is None, remove that key from kwargs @@ -141,6 +141,16 @@ def ollama_chat_completion( logger.warning("Ollama model only supports a single image per message. Using only the first images.") message["images"] = images + return kwargs + + +def ollama_chat_completion( + request: ChatCompletionRequest, + client: Callable[..., Any], +) -> ChatCompletionResponse: + validate(request) + kwargs = format_kwargs(request) + try: start_time = time.time() response: ChatResponse = client(**kwargs) @@ -164,7 +174,7 @@ def ollama_chat_completion( tool_name = tool_call.function.name if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]: errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n" - tool_args = tool_call.function.arguments + tool_args = dict(tool_call.function.arguments) parsed_tool_calls.append( ToolCall( id="", @@ -206,7 +216,65 @@ def ollama_chat_completion( ) -def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]: +async def ollama_chat_completion_stream( + request: ChatCompletionRequest, + client: Callable[..., Any], +) -> AsyncGenerator[ChatCompletionChunk, None]: + validate(request) + kwargs = format_kwargs(request) + + start_time = time.time() + stream = await client(**kwargs) + + async for chunk in stream: + errors = "" + # Handle tool calls + tool_calls: list[PartialToolCall] | None = None + if chunk.message.tool_calls: + parsed_tool_calls: list[PartialToolCall] = [] + for tool_call in chunk.message.tool_calls: + tool_name = tool_call.function.name + if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]: + errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n" + tool_args = tool_call.function.arguments + + parsed_tool_calls.append( + PartialToolCall( + id="", + function=PartialFunction( + name=tool_name, + arguments=tool_args, + ), + ) + ) + tool_calls = parsed_tool_calls + + current_time = time.time() + response_duration = round(current_time - start_time, 4) + + delta = ChatCompletionDelta( + content=chunk.message.content or "", + role=Role.ASSISTANT, + tool_calls=tool_calls, + ) + choice_obj = ChatCompletionChoiceStream( + delta=delta, + finish_reason=chunk.done_reason, + index=0, + ) + chunk_obj = ChatCompletionChunk( + choices=[choice_obj], + errors=errors.strip(), + completion_tokens=chunk.get("eval_count", None), + 
prompt_tokens=chunk.get("prompt_eval_count", None), + response_duration=response_duration, + ) + yield chunk_obj + + +def ollama_client( + host: str | None = None, timeout: float | None = None, async_client: bool = False +) -> Callable[..., Any]: """Create an Ollama client instance based on the specified host or will read from the OLLAMA_HOST environment variable. Args: @@ -226,7 +294,7 @@ def ollama_client(host: str | None = None, timeout: float | None = None) -> Call host = "http://localhost:11434" def client_callable(**kwargs: Any) -> Any: - client = Client(host=host, timeout=timeout) + client = AsyncClient(host=host, timeout=timeout) if async_client else Client(host=host, timeout=timeout) return client.chat(**kwargs) return client_callable diff --git a/src/not_again_ai/llm/chat_completion/providers/openai_api.py b/src/not_again_ai/llm/chat_completion/providers/openai_api.py index d252c11..e50b039 100644 --- a/src/not_again_ai/llm/chat_completion/providers/openai_api.py +++ b/src/not_again_ai/llm/chat_completion/providers/openai_api.py @@ -224,6 +224,7 @@ async def openai_chat_completion_stream( validate(request) kwargs = format_kwargs(request) + start_time = time.time() stream = await client(**kwargs) async for chunk in stream: @@ -234,7 +235,7 @@ async def openai_chat_completion_stream( chunk = chunk.to_dict() choices: list[ChatCompletionChoiceStream] = [] - for choice_index, choice in enumerate(chunk["choices"]): + for choice in chunk["choices"]: content = choice.get("delta", {}).get("content", "") if not content: content = "" @@ -251,12 +252,6 @@ async def openai_chat_completion_stream( tool_name = tool_call.get("function", {}).get("name", None) if not tool_name: tool_name = "" - # Check if the tool name is valid (one of the tool names in the request) - if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]: - errors += ( - f"Choice {choice_index}: Tool call {tool_call} has an invalid tool name: {tool_name}\n" - ) - tool_args = tool_call.get("function", {}).get("arguments", "") if not tool_args: tool_args = "" @@ -300,7 +295,26 @@ async def openai_chat_completion_stream( ) choices.append(choice_obj) - chunk_obj = ChatCompletionChunk(choices=choices) + current_time = time.time() + response_duration = round(current_time - start_time, 4) + + if "usage" in chunk and chunk["usage"] is not None: + completion_tokens = chunk["usage"].get("completion_tokens", None) + prompt_tokens = chunk["usage"].get("prompt_tokens", None) + system_fingerprint = chunk.get("system_fingerprint", None) + else: + completion_tokens = None + prompt_tokens = None + system_fingerprint = None + + chunk_obj = ChatCompletionChunk( + choices=choices, + errors=errors.strip(), + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + response_duration=response_duration, + system_fingerprint=system_fingerprint, + ) yield chunk_obj @@ -331,6 +345,7 @@ def create_client_callable_stream( def client_callable(**kwargs: Any) -> Coroutine[Any, Any, Any]: client = client_class(**filtered_args) + kwargs["stream_options"] = {"include_usage": True} stream = client.chat.completions.create(**kwargs) return stream diff --git a/src/not_again_ai/llm/chat_completion/types.py b/src/not_again_ai/llm/chat_completion/types.py index 18ecfbd..554686f 100644 --- a/src/not_again_ai/llm/chat_completion/types.py +++ b/src/not_again_ai/llm/chat_completion/types.py @@ -54,7 +54,7 @@ class Function(BaseModel): class PartialFunction(BaseModel): name: str - arguments: str + arguments: str | dict[str, 
Any] class ToolCall(BaseModel): diff --git a/tests/llm/chat_completion/test_chat_completion_stream.py b/tests/llm/chat_completion/test_chat_completion_stream.py index 2459971..6940148 100644 --- a/tests/llm/chat_completion/test_chat_completion_stream.py +++ b/tests/llm/chat_completion/test_chat_completion_stream.py @@ -5,15 +5,20 @@ import pytest from not_again_ai.llm.chat_completion import chat_completion_stream +from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client from not_again_ai.llm.chat_completion.providers.openai_api import openai_client from not_again_ai.llm.chat_completion.types import ( + AssistantMessage, ChatCompletionRequest, + Function, ImageContent, ImageDetail, ImageUrl, MessageT, SystemMessage, TextContent, + ToolCall, + ToolMessage, UserMessage, ) from not_again_ai.llm.prompting.compile_prompt import encode_image @@ -388,6 +393,161 @@ async def test_chat_completion_stream_invalid_params(openai_aoai_client_fixture: # region Ollama +@pytest.fixture( + params=[ + {"async_client": True}, + ] +) +def ollama_client_fixture(request: pytest.FixtureRequest) -> Callable[..., Any]: + return ollama_client(**request.param) + + +async def test_chat_completion_stream_ollama(ollama_client_fixture: Callable[..., Any]) -> None: + request = ChatCompletionRequest( + model="llama3.2-vision:11b-instruct-q4_K_M", + messages=[ + SystemMessage(content="Hello, world!"), + UserMessage(content="What is the capital of France?"), + ], + max_completion_tokens=100, + frequency_penalty=1.2, + top_p=0.8, + context_window=1000, + ) + async for chunk in chat_completion_stream(request, "ollama", ollama_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_ollama_structured_output(ollama_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="9.11 and 9.9 -- which is bigger?"), + ] + json_schema = { + "name": "reasoning_schema", + "strict": True, + "schema": { + "type": "object", + "properties": { + "reasoning_steps": { + "type": "array", + "items": {"type": "string"}, + "description": "The reasoning steps leading to the final conclusion.", + }, + "answer": { + "type": "string", + "description": "The final answer, taking into account the reasoning steps.", + }, + }, + "required": ["reasoning_steps", "answer"], + "additionalProperties": False, + }, + "description": "A schema for structured output that includes reasoning steps and the final answer.", + } + + request = ChatCompletionRequest( + messages=messages, + model="llama3.2-vision:11b-instruct-q4_K_M", + max_completion_tokens=400, + temperature=0.2, + json_mode=False, + structured_outputs=json_schema, + ) + + async for chunk in chat_completion_stream(request, "ollama", ollama_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_ollama_multiple_tools(ollama_client_fixture: Callable[..., Any]) -> None: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. 
Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + ] + messages: list[MessageT] = [ + SystemMessage(content="Call the get_current_weather function once for each city that the user mentions."), + UserMessage(content="What's the current weather like in Boston, MA and New York, NY today?"), + ] + request = ChatCompletionRequest( + model="qwen2.5:14b", + messages=messages, + tools=tools, + max_completion_tokens=400, + temperature=0, + ) + async for chunk in chat_completion_stream(request, "ollama", ollama_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_ollama_tool_message(ollama_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="What is the weather in Boston, MA?"), + AssistantMessage( + content="", + tool_calls=[ + ToolCall( + id="abc123", + function=Function( + name="get_current_weather", + arguments={"location": "Boston, MA"}, + ), + ) + ], + ), + ToolMessage(name="abc123", content="The weather in Boston, MA is 70 degrees Fahrenheit."), + ] + request = ChatCompletionRequest( + model="qwen2.5:14b", + messages=messages, + max_completion_tokens=300, + temperature=0.3, + ) + async for chunk in chat_completion_stream(request, "ollama", ollama_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) + + +async def test_chat_completion_stream_ollama_vision_multiple_images(ollama_client_fixture: Callable[..., Any]) -> None: + messages: list[MessageT] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage( + content=[ + TextContent(text="What are the animals in the images? Reply in one word for each animal."), + ImageContent( + image_url=ImageUrl(url=f"data:image/jpeg;base64,{encode_image(cat_image)}", detail=ImageDetail.LOW) + ), + ImageContent( + image_url=ImageUrl(url=f"data:image/jpeg;base64,{encode_image(dog_image)}", detail=ImageDetail.LOW) + ), + ] + ), + ] + request = ChatCompletionRequest( + messages=messages, + model="llama3.2-vision:11b-instruct-q4_K_M", + max_completion_tokens=100, + ) + async for chunk in chat_completion_stream(request, "ollama", ollama_client_fixture): + print(chunk.model_dump(mode="json", exclude_none=True)) # endregion
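As the notebook output above shows, the OpenAI provider streams tool-call arguments as string fragments (the `id` and function name arrive only on the first fragment), while the Ollama provider yields a complete arguments dict per chunk. The sketch below is one way a caller might reassemble these into whole tool calls using the `ChatCompletionChunk` / `PartialToolCall` types added in this diff; the helper name `collect_tool_calls` and the "a non-empty `id` starts a new call" heuristic are illustrative assumptions based on the output above, not part of the patch.

```python
# Illustrative sketch, not part of the patch: reassemble streamed tool calls.
# Field names (choices, delta, tool_calls, function.name / function.arguments)
# follow the ChatCompletionChunk / PartialToolCall types added in this diff.
import json
from collections.abc import Callable
from typing import Any

from not_again_ai.llm.chat_completion import chat_completion_stream
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest


async def collect_tool_calls(
    request: ChatCompletionRequest, provider: str, client: Callable[..., Any]
) -> list[dict[str, Any]]:
    calls: list[dict[str, Any]] = []
    async for chunk in chat_completion_stream(request, provider, client):
        for choice in chunk.choices:
            for tool_call in choice.delta.tool_calls or []:
                args = tool_call.function.arguments
                if isinstance(args, dict):
                    # Ollama yields a complete arguments dict in a single chunk.
                    calls.append({"id": tool_call.id, "name": tool_call.function.name, "arguments": args})
                elif tool_call.id:
                    # OpenAI sends the id and function name only on the first fragment.
                    calls.append({"id": tool_call.id, "name": tool_call.function.name, "arguments": args})
                elif calls:
                    # Later OpenAI fragments carry argument text to concatenate.
                    calls[-1]["arguments"] += args
    # Parse any concatenated JSON argument strings into dicts.
    for call in calls:
        if isinstance(call["arguments"], str) and call["arguments"]:
            call["arguments"] = json.loads(call["arguments"])
    return calls
```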
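Because `create_client_callable_stream` now passes `stream_options={"include_usage": True}`, the final OpenAI chunk arrives with an empty `choices` list and carries the token counts; the Ollama output above suggests its final chunk reports them as well (via `eval_count` / `prompt_eval_count`). A minimal sketch of draining a text-only stream under that assumption follows; `collect_text` is a hypothetical helper, not part of the patch.

```python
# Minimal sketch, assuming only the ChatCompletionChunk fields shown in this diff:
# join the streamed deltas into the full reply and pick up the final usage count.
from collections.abc import Callable
from typing import Any

from not_again_ai.llm.chat_completion import chat_completion_stream
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest


async def collect_text(
    request: ChatCompletionRequest, provider: str, client: Callable[..., Any]
) -> tuple[str, int | None]:
    parts: list[str] = []
    completion_tokens: int | None = None
    async for chunk in chat_completion_stream(request, provider, client):
        for choice in chunk.choices:
            parts.append(choice.delta.content or "")
        if chunk.completion_tokens is not None:
            # Usage is only populated on the chunk that carries it (the last one).
            completion_tokens = chunk.completion_tokens
    return "".join(parts), completion_tokens
```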