
Commit

ollama streaming api
DavidKoleczek committed Feb 8, 2025
1 parent 84ead84 commit 14c03c3
Showing 8 changed files with 458 additions and 23 deletions.
188 changes: 188 additions & 0 deletions notebooks/llm/03_llm_streaming.ipynb
@@ -0,0 +1,188 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Streaming APIs\n",
"\n",
"This notebooks shows examples for how to use the streaming APIs for both OpenAI and Ollama.\n",
"\n",
"When instantiating the client, set `async_client=True` to use the async client which supports streaming.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from not_again_ai.llm.chat_completion import chat_completion_stream\n",
"from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client\n",
"from not_again_ai.llm.chat_completion.providers.openai_api import openai_client\n",
"from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, SystemMessage, UserMessage\n",
"\n",
"openai_client = openai_client(async_client=True)\n",
"ollama_client = ollama_client(async_client=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0}], 'errors': '', 'response_duration': 1.5805}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': 'call_W5ScLS0Atqtt4awudPS8pA6G', 'function': {'name': 'get_current_weather', 'arguments': ''}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5835}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '{\"lo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5846}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'catio'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5858}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n\": \"B'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5868}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'osto'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.589}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n, MA'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5899}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\", \"fo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5909}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'rmat'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5918}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\": \"f'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5934}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ahrenh'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5942}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'eit\"'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5955}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '}'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5964}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': 'call_sETMkonfyDQP5ndr8kGpHVcY', 'function': {'name': 'get_current_weather', 'arguments': ''}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5974}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '{\"lo'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5983}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'catio'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.5998}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'n\": \"N'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6006}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ew Y'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6383}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'ork, '}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6393}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'NY\", \"'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6403}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'form'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6413}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'at\": '}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6422}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': '\"fahre'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6431}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 'nhei'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6449}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'function': {'name': '', 'arguments': 't\"}'}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 1.6458}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0, 'finish_reason': 'tool_calls'}], 'errors': '', 'response_duration': 1.647}\n",
"{'choices': [], 'errors': '', 'completion_tokens': 62, 'prompt_tokens': 112, 'response_duration': 1.6504, 'system_fingerprint': 'fp_72ed7ab54c'}\n"
]
}
],
"source": [
"# OpenAI Streaming Example with tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"format\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
" },\n",
" },\n",
" \"required\": [\"location\", \"format\"],\n",
" },\n",
" },\n",
" },\n",
"]\n",
"messages = [\n",
" SystemMessage(content=\"Call the get_current_weather function once for each city that the user mentions.\"),\n",
" UserMessage(content=\"What's the current weather like in Boston, MA and New York, NY today?\"),\n",
"]\n",
"request = ChatCompletionRequest(\n",
" model=\"gpt-4o-mini-2024-07-18\",\n",
" messages=messages,\n",
" tools=tools,\n",
" max_completion_tokens=400,\n",
" temperature=0,\n",
")\n",
"\n",
"async for chunk in chat_completion_stream(request, \"openai\", openai_client):\n",
" print(chunk.model_dump(mode=\"json\", exclude_none=True), flush=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': '', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'Boston, MA'}}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 0.4406}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant', 'tool_calls': [{'id': '', 'function': {'name': 'get_current_weather', 'arguments': {'location': 'New York, NY'}}, 'type': 'function'}]}, 'index': 0}], 'errors': '', 'response_duration': 0.7869}\n",
"{'choices': [{'delta': {'content': '', 'role': 'assistant'}, 'index': 0, 'finish_reason': 'stop'}], 'errors': '', 'completion_tokens': 47, 'prompt_tokens': 204, 'response_duration': 0.8152}\n"
]
}
],
"source": [
"# Ollama Streaming Example with tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"format\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
" },\n",
" },\n",
" \"required\": [\"location\", \"format\"],\n",
" },\n",
" },\n",
" },\n",
"]\n",
"messages = [\n",
" SystemMessage(content=\"Call the get_current_weather function once for each city that the user mentions.\"),\n",
" UserMessage(content=\"What's the current weather like in Boston, MA and New York, NY today?\"),\n",
"]\n",
"request = ChatCompletionRequest(\n",
" model=\"qwen2.5:14b\",\n",
" messages=messages,\n",
" tools=tools,\n",
" max_completion_tokens=400,\n",
" temperature=0,\n",
")\n",
"async for chunk in chat_completion_stream(request, \"ollama\", ollama_client):\n",
" print(chunk.model_dump(mode=\"json\", exclude_none=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
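In the OpenAI streaming output above, each tool call's arguments arrive as string fragments spread across chunks, while the Ollama output further below delivers each call's arguments as a complete dict in a single delta. The following is a minimal sketch, not part of this commit, of how a consumer might reassemble the OpenAI-style fragments; it assumes chunk dicts shaped exactly like the printed output, and the helper name is illustrative.

import json
from typing import Any


def accumulate_tool_calls(chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Reassemble OpenAI-style streamed tool calls from chunk dicts like those printed above."""
    calls: list[dict[str, Any]] = []
    for chunk in chunks:
        for choice in chunk.get("choices", []):
            for tool_call in choice.get("delta", {}).get("tool_calls", []):
                if tool_call.get("id"):
                    # A non-empty id marks the start of a new tool call.
                    calls.append({"id": tool_call["id"], "name": tool_call["function"]["name"], "arguments": ""})
                elif calls:
                    # Later chunks carry argument fragments belonging to the most recent call.
                    calls[-1]["arguments"] += tool_call["function"].get("arguments", "")
    for call in calls:
        # The accumulated fragments form one JSON document once the stream has finished.
        call["arguments"] = json.loads(call["arguments"]) if call["arguments"] else {}
    return calls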
2 changes: 1 addition & 1 deletion noxfile.py
@@ -17,7 +17,7 @@

@session(python=["3.11", "3.12"])
def test(s: Session) -> None:
s.install(".[data,llm,statistics,viz]", "pytest", "pytest-cov", "pytest-randomly")
s.install(".[data,llm,statistics,viz]", "pytest", "pytest-asyncio", "pytest-cov", "pytest-randomly")

    # Skip tests in directories specified by the SKIP_TESTS_NAAI environment variable.
skip_tests = os.getenv("SKIP_TESTS_NAAI", "")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -155,6 +155,7 @@ filterwarnings = [
"ignore::pytest.PytestUnraisableExceptionWarning"
]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.run]
branch = true
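Together with the pytest-asyncio dependency added in noxfile.py above, asyncio_mode = "auto" and the default fixture loop scope let async test functions be collected and run without per-test decorators. A minimal sketch of such a test, illustrative rather than taken from this commit, assuming a local Ollama server with the notebook's model pulled:

from not_again_ai.llm.chat_completion import chat_completion_stream
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, UserMessage


async def test_ollama_chat_completion_stream() -> None:
    # With asyncio_mode = "auto", pytest-asyncio runs this coroutine without an explicit marker.
    client = ollama_client(async_client=True)
    request = ChatCompletionRequest(
        model="qwen2.5:14b",
        messages=[UserMessage(content="Say hello in one word.")],
        max_completion_tokens=20,
    )
    chunks = [chunk async for chunk in chat_completion_stream(request, "ollama", client)]
    assert len(chunks) > 0
    assert all(chunk.errors == "" for chunk in chunks)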
5 changes: 4 additions & 1 deletion src/not_again_ai/llm/chat_completion/interface.py
@@ -1,7 +1,7 @@
from collections.abc import AsyncGenerator, Callable
from typing import Any

-from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion
+from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream
from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse

@@ -54,5 +54,8 @@ async def chat_completion_stream(
if provider == "openai" or provider == "azure_openai":
async for chunk in openai_chat_completion_stream(request, client):
yield chunk
elif provider == "ollama":
async for chunk in ollama_chat_completion_stream(request, client):
yield chunk
else:
raise ValueError(f"Provider {provider} not supported")
92 changes: 80 additions & 12 deletions src/not_again_ai/llm/chat_completion/providers/ollama_api.py
@@ -1,19 +1,25 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
import json
import os
import re
import time
from typing import Any, Literal, cast

from loguru import logger
-from ollama import ChatResponse, Client, ResponseError
+from ollama import AsyncClient, ChatResponse, Client, ResponseError

from not_again_ai.llm.chat_completion.types import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionChoiceStream,
ChatCompletionChunk,
ChatCompletionDelta,
ChatCompletionRequest,
ChatCompletionResponse,
Function,
PartialFunction,
PartialToolCall,
Role,
ToolCall,
)

@@ -51,14 +57,8 @@ def validate(request: ChatCompletionRequest) -> None:
raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")


def ollama_chat_completion(
request: ChatCompletionRequest,
client: Callable[..., Any],
) -> ChatCompletionResponse:
validate(request)

def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
kwargs = request.model_dump(mode="json", exclude_none=True)

    # For each key in OLLAMA_PARAMETER_MAP:
    # if its mapped value is not None, rename that key in kwargs to the mapped value;
    # if it is None, remove the key from kwargs.
@@ -141,6 +141,16 @@ def ollama_chat_completion(
logger.warning("Ollama model only supports a single image per message. Using only the first images.")
message["images"] = images

return kwargs


def ollama_chat_completion(
request: ChatCompletionRequest,
client: Callable[..., Any],
) -> ChatCompletionResponse:
validate(request)
kwargs = format_kwargs(request)

try:
start_time = time.time()
response: ChatResponse = client(**kwargs)
@@ -164,7 +174,7 @@ def ollama_chat_completion(
tool_name = tool_call.function.name
if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
-            tool_args = tool_call.function.arguments
+            tool_args = dict(tool_call.function.arguments)
parsed_tool_calls.append(
ToolCall(
id="",
@@ -206,7 +216,65 @@
)


-def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
async def ollama_chat_completion_stream(
request: ChatCompletionRequest,
client: Callable[..., Any],
) -> AsyncGenerator[ChatCompletionChunk, None]:
validate(request)
kwargs = format_kwargs(request)

start_time = time.time()
stream = await client(**kwargs)

async for chunk in stream:
errors = ""
# Handle tool calls
tool_calls: list[PartialToolCall] | None = None
if chunk.message.tool_calls:
parsed_tool_calls: list[PartialToolCall] = []
for tool_call in chunk.message.tool_calls:
tool_name = tool_call.function.name
if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
tool_args = tool_call.function.arguments

parsed_tool_calls.append(
PartialToolCall(
id="",
function=PartialFunction(
name=tool_name,
arguments=tool_args,
),
)
)
tool_calls = parsed_tool_calls

current_time = time.time()
response_duration = round(current_time - start_time, 4)

delta = ChatCompletionDelta(
content=chunk.message.content or "",
role=Role.ASSISTANT,
tool_calls=tool_calls,
)
choice_obj = ChatCompletionChoiceStream(
delta=delta,
finish_reason=chunk.done_reason,
index=0,
)
chunk_obj = ChatCompletionChunk(
choices=[choice_obj],
errors=errors.strip(),
completion_tokens=chunk.get("eval_count", None),
prompt_tokens=chunk.get("prompt_eval_count", None),
response_duration=response_duration,
)
yield chunk_obj


def ollama_client(
host: str | None = None, timeout: float | None = None, async_client: bool = False
) -> Callable[..., Any]:
"""Create an Ollama client instance based on the specified host or will read from the OLLAMA_HOST environment variable.
Args:
@@ -226,7 +294,7 @@ def ollama_client(host: str | None = None, timeout: float | None = None) -> Call
host = "http://localhost:11434"

def client_callable(**kwargs: Any) -> Any:
-        client = Client(host=host, timeout=timeout)
+        client = AsyncClient(host=host, timeout=timeout) if async_client else Client(host=host, timeout=timeout)
return client.chat(**kwargs)

return client_callable
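As the code above shows, ollama_client resolves the host from its argument, then from the OLLAMA_HOST environment variable, falling back to http://localhost:11434, and the returned callable now wraps either Client.chat or AsyncClient.chat. A brief usage sketch; the host and timeout values are illustrative:

from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client

# Synchronous callable wrapping ollama.Client.chat; suitable for ollama_chat_completion.
sync_chat = ollama_client(host="http://localhost:11434", timeout=30.0)

# Async callable wrapping ollama.AsyncClient.chat; required for ollama_chat_completion_stream.
async_chat = ollama_client(host="http://localhost:11434", timeout=30.0, async_client=True)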
(The remaining 3 changed files were not loaded on this page.)
