
feat: Add ollama support [DCH-460] #82

Merged · 5 commits · Jan 14, 2025
Changes from 4 commits
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -8,6 +8,7 @@ dependencies = [
"anthropic[bedrock]>=0.37.1",
"httpx==0.27",
"instructor[anthropic]>1.6",
"ollama>=0.4.5",
"openai>=1.58.1"
]

@@ -66,8 +67,6 @@ target-version = "py311"
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"ANN101", # Self should never be type annotated.
"ANN102", # cls should never be type annotated.
"B008" # Allow function call in arguments; this is common in FastAPI.
]
fixable = ["ALL"]
1 change: 0 additions & 1 deletion src/cloai/llm/bedrock.py
@@ -62,7 +62,6 @@ async def run(self, system_prompt: str, user_prompt: str) -> str:
system=system_prompt,
messages=[{"role": "user", "content": user_prompt}],
)

return message.content[0].text # type: ignore[union-attr]

async def call_instructor(
134 changes: 134 additions & 0 deletions src/cloai/llm/ollama.py
@@ -0,0 +1,134 @@
"""Ollama LLM client implementation."""

import json
from typing import Any, TypeVar, get_args, get_origin

import ollama
import pydantic

from cloai.llm.utils import LlmBaseClass

T = TypeVar("T")


class OllamaLlm(LlmBaseClass):
"""Client for Ollama API."""

def __init__(
self,
model: str,
base_url: str,
) -> None:
"""Initialize Ollama client.

Args:
model: The model to run, must already be installed on the host via ollama.
base_url: The URL of the Ollama API.
"""
self.model = model
self.client = ollama.AsyncClient(host=base_url)

async def run(self, system_prompt: str, user_prompt: str) -> str:
"""Call Ollama model."""
response = await self.client.chat(
model=self.model,
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
],
)
return response["message"]["content"]

    async def call_instructor(
        self,
        response_model: type[T],
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 4096,
    ) -> T:
        """Run a type-safe large language model query.

        This function uses Pydantic to convert an arbitrary class to a JSON
        schema. This is unlikely to be fool-proof, but we can deal with issues
        as they arise.

        Args:
            response_model: The response model; may be a Pydantic model or a basic Python type.
            system_prompt: The system prompt.
            user_prompt: The user prompt.
            max_tokens: The maximum number of tokens to allow.

        Returns:
            The response as the requested object.
        """
        default_max_tokens = 4096
        if max_tokens != default_max_tokens:
            msg = "max_tokens has not yet been implemented in Ollama."
            raise NotImplementedError(msg)

        # Use Pydantic for converting an arbitrary class to JSON schema.
        schema = pydantic.create_model(
            response_model.__name__,
            field=(response_model, ...),
        ).model_json_schema()

        response = await self.client.chat(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ],
            format=schema,
        )

        data = json.loads(response.message.content)["field"]  # type: ignore[arg-type]
        return _model_and_data_to_object(response_model, data)


def _model_and_data_to_object(cls: type[T], data: Any) -> Any: # noqa: ANN401
"""Convert JSON data to the specified type.

Args:
cls: The target class type.
data: The JSON data to convert.

Returns:
An instance of the target class.
"""
# Pydantic models
try:
return cls.model_validate(data) # type: ignore[call-arg, attr-defined]
except AttributeError:

Check warning on line 113 in src/cloai/llm/ollama.py

View check run for this annotation

Codecov / codecov/patch

src/cloai/llm/ollama.py#L113

Added line #L113 was not covered by tests
        # Not a Pydantic model.
        pass

    # Lists/tuples
    if cls in (list, tuple):
        return cls(data)  # type: ignore[call-arg]

    if get_origin(cls) in (list, tuple):
        item_types = get_args(cls)
        if len(item_types) > 1:
            msg = "Only one item type may be present in a list/tuple type."
            raise NotImplementedError(msg)
        return cls(_model_and_data_to_object(item_types[0], item) for item in data)  # type: ignore[call-arg]

    # Basic Python types
    if cls in (int, float, str, bool):
        return cls(data)  # type: ignore[call-arg]

    # If we get here, we don't know how to handle this type
    msg = f"Unable to convert data to type {cls}"
    raise ValueError(msg)
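
A quick illustrative sketch of the wrapping used by call_instructor above (the Numbers name and list[int] type are invented for the example and not part of this PR): the requested type is nested under a single "field" key so that even non-BaseModel types produce a valid JSON schema for Ollama's format argument, and the reply is unwrapped through the same key before _model_and_data_to_object converts it.

import json

import pydantic

# Hypothetical example type; any response_model passed to call_instructor is wrapped the same way.
wrapper = pydantic.create_model("Numbers", field=(list[int], ...))
schema = wrapper.model_json_schema()
# "schema" describes an object whose single "field" property is an array of integers.

# A conforming model reply would then be unwrapped like so:
reply = '{"field": [1, 2, 3]}'
data = json.loads(reply)["field"]  # -> [1, 2, 3], ready for _model_and_data_to_object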

50 changes: 35 additions & 15 deletions tests/integration/test_llm.py
@@ -5,7 +5,9 @@
import pydantic
import pytest

from cloai.llm import bedrock, llm, openai
from cloai.llm import bedrock, llm, ollama, openai

LLM_MODELS = ["openai", "bedrock", "ollama"]


@pytest.fixture
@@ -30,24 +32,34 @@ def bedrock_anthropic_model() -> llm.LargeLanguageModel:
return llm.LargeLanguageModel(client=client)


@pytest.fixture
def ollama_model() -> llm.LargeLanguageModel:
"""Creates the Ollama client. Requires ollama installed with llama3.2:1b."""
client = ollama.OllamaLlm("llama3.2:1b", "http://localhost:11434")
return llm.LargeLanguageModel(client=client)
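
For reference, a minimal stand-alone sketch of what this fixture assumes: an Ollama server reachable on the default local port with llama3.2:1b already pulled (for example via ollama pull llama3.2:1b). The run call below assumes LargeLanguageModel.run mirrors the client's (system_prompt, user_prompt) signature.

import asyncio

from cloai.llm import llm, ollama


async def main() -> None:
    # Assumes a local Ollama server and the llama3.2:1b model are available.
    client = ollama.OllamaLlm("llama3.2:1b", "http://localhost:11434")
    model = llm.LargeLanguageModel(client=client)
    print(await model.run("You are a helpful assistant.", "What is 2+2?"))


asyncio.run(main())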


@pytest.fixture
def model(
request: pytest.FixtureRequest,
openai_model: llm.LargeLanguageModel,
bedrock_anthropic_model: llm.LargeLanguageModel,
ollama_model: llm.LargeLanguageModel,
) -> llm.LargeLanguageModel:
"""Fetches the LLM."""
name = request.param
if name == "openai":
return openai_model
if name == "bedrock":
return bedrock_anthropic_model
if name == "ollama":
return ollama_model

msg = "Wrong model name."
raise ValueError(msg)


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_run(model: llm.LargeLanguageModel) -> None:
"""Test the run command."""
@@ -60,27 +72,33 @@ async def test_run(model: llm.LargeLanguageModel) -> None:
assert len(actual) > 0


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.asyncio
async def test_call_instructor(model: llm.LargeLanguageModel) -> None:
"""Test the call_instructor command."""
class Response(pydantic.BaseModel):
"""Testing response model for instructor."""

class Response(pydantic.BaseModel):
grade: int = pydantic.Field(..., lt=10, gt=0)
grade: int = pydantic.Field(..., lt=10, gt=0)


@pytest.mark.parametrize("response", [Response, int])
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_call_instructor(
model: llm.LargeLanguageModel,
response: type[Response] | type[int],
) -> None:
"""Test the call_instructor command."""
system_prompt = "Return the user message."
user_prompt = "{'grade': 3}"

actual = await model.call_instructor(
response_model=Response,
response_model=response,
system_prompt=system_prompt,
user_prompt=user_prompt,
)

assert isinstance(actual, Response)
assert isinstance(actual, response)


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_density(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_density command."""
@@ -108,7 +126,7 @@ async def test_chain_of_density(model: llm.LargeLanguageModel) -> None:
assert len(actual) > 0


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_verification_str(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_verification command."""
@@ -122,13 +140,15 @@ async def test_chain_of_verification_str(model: llm.LargeLanguageModel) -> None:
)

assert isinstance(actual, str)
assert "horse" in actual.lower()


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_verification_model(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_verification command."""
"""Test the chain_of_verification command.

This test may be unstable with Ollama depending on the model used.
"""
text = "Lea is 9 years old. She likes riding horses."

class Response(pydantic.BaseModel):
@@ -4,22 +4,25 @@
but they are the best we can do without connecting to remote servers on every test.
"""

import json
import types
from unittest import mock

import pydantic
import pytest
import pytest_mock

from cloai.llm import bedrock, openai, utils
from cloai.llm import bedrock, ollama, openai, utils

TEST_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
TEST_SYSTEM_PROMPT = "You are a helpful assistant."
TEST_USER_PROMPT = "What is 2+2?"
TEST_RUN_RESPONSE = "Hello world!"

LLM_TYPE = bedrock.AnthropicBedrockLlm | openai.OpenAiLlm | openai.AzureLlm
llms = ("azure", "anthropic_bedrock", "openai")
LLM_TYPE = (
bedrock.AnthropicBedrockLlm | openai.OpenAiLlm | openai.AzureLlm | ollama.OllamaLlm
)
llms = ("azure", "anthropic_bedrock", "openai", "ollama")


class _TestResponse(pydantic.BaseModel):
@@ -134,6 +137,17 @@ def openai_llm(
)


@pytest.fixture
def ollama_llm(mocker: pytest_mock.MockerFixture) -> ollama.OllamaLlm:
"""Create the mocked anthropic bedrock llm."""
response = {"message": {"content": TEST_RUN_RESPONSE}}
mocker.patch("ollama.AsyncClient.chat", return_value=response)
return ollama.OllamaLlm(
model=TEST_MODEL,
base_url="somethinglocal",
)


@pytest.fixture
def azure_llm(
mock_azure_client: mock.MagicMock,
@@ -154,6 +168,7 @@ def llm(
openai_llm: openai.OpenAiLlm,
azure_llm: openai.AzureLlm,
anthropic_bedrock_llm: bedrock.AnthropicBedrockLlm,
ollama_llm: ollama.OllamaLlm,
) -> utils.LlmBaseClass:
"""Create the mocked llm."""
name = request.param
Expand All @@ -163,6 +178,8 @@ def llm(
return anthropic_bedrock_llm
if name == "azure":
return azure_llm
if name == "ollama":
return ollama_llm
raise NotImplementedError


@@ -201,7 +218,17 @@ async def test_call_instructor_method(
) -> None:
"""Test the call_instructor method."""
expected_response = _TestResponse(answer="4")
llm._instructor.chat.completions.create.return_value = expected_response # type: ignore[call-overload, attr-defined]
if isinstance(llm, ollama.OllamaLlm):
# Ollama doesn't use instructor and therefore requires custom handling.
class Content(pydantic.BaseModel):
content: str = json.dumps({"field": _TestResponse(answer="4").model_dump()})

class Response(pydantic.BaseModel):
message: Content = Content()

llm.client.chat.return_value = Response() # type: ignore[attr-defined]
else:
llm._instructor.chat.completions.create.return_value = expected_response # type: ignore[call-overload, attr-defined]

result = await llm.call_instructor(
_TestResponse,
2 changes: 1 addition & 1 deletion tests/unit/llm/test_llm_unit.py
@@ -32,7 +32,7 @@ def test_recursive_pydantic_model_dump_primitive() -> None:
assert actual == expected


def test_recursive_pydantic_model_dump_recusive() -> None:
def test_recursive_pydantic_model_dump_recursive() -> None:
"""Test dumping a model containing a model."""
model = ModelRecursive()
expected = model.model_dump()