feat: Add ollama support [DCH-460] #82

Merged (5 commits) on Jan 14, 2025
6 changes: 2 additions & 4 deletions README.md
@@ -29,12 +29,10 @@ client = cloai.OpenAiLlm(api_key="your_key", model="gpt-4o")

```python
import cloai
import instructor

client = cloai.OpenAiLlm(
api_key="your_key", model="llama3.2",
client = cloai.OllamaLlm(
model="llama3.2",
base_url="http://localhost:11434/v1",
instructor_mode=instructor.Mode.JSON
)
```

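Not part of the diff, but for orientation: a minimal end-to-end sketch of how the new client could be used. It assumes a local Ollama server on the default port with `llama3.2` already pulled, and that the `LargeLanguageModel` wrapper forwards `run` and `call_instructor` the way the integration tests below exercise them.

```python
# Sketch only: assumes `ollama serve` is running on localhost:11434 and the
# llama3.2 model has been pulled; the wrapper API mirrors the integration tests.
import asyncio

import pydantic

import cloai


class Grade(pydantic.BaseModel):
    grade: int = pydantic.Field(..., gt=0, lt=10)


async def main() -> None:
    client = cloai.OllamaLlm(model="llama3.2", base_url="http://localhost:11434")
    model = cloai.LargeLanguageModel(client=client)

    # Plain text completion.
    text = await model.run(
        system_prompt="You are terse.",
        user_prompt="Say hello.",
    )

    # Structured output coerced into a Pydantic model.
    grade = await model.call_instructor(
        response_model=Grade,
        system_prompt="Return the user message.",
        user_prompt="{'grade': 3}",
    )
    print(text, grade.grade)


asyncio.run(main())
```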
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -8,6 +8,7 @@ dependencies = [
"anthropic[bedrock]>=0.37.1",
"httpx==0.27",
"instructor[anthropic]>1.6",
"ollama>=0.4.5",
"openai>=1.58.1"
]

@@ -66,8 +67,6 @@ target-version = "py311"
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"ANN101", # Self should never be type annotated.
"ANN102", # cls should never be type annotated.
"B008" # Allow function call in arguments; this is common in FastAPI.
]
fixable = ["ALL"]
9 changes: 8 additions & 1 deletion src/cloai/__init__.py
@@ -2,6 +2,13 @@

from cloai.llm.bedrock import AnthropicBedrockLlm
from cloai.llm.llm import LargeLanguageModel
from cloai.llm.ollama import OllamaLlm
from cloai.llm.openai import AzureLlm, OpenAiLlm

__all__ = ("AnthropicBedrockLlm", "AzureLlm", "LargeLanguageModel", "OpenAiLlm")
__all__ = (
"AnthropicBedrockLlm",
"AzureLlm",
"LargeLanguageModel",
"OllamaLlm",
"OpenAiLlm",
)
1 change: 0 additions & 1 deletion src/cloai/llm/bedrock.py
@@ -62,7 +62,6 @@ async def run(self, system_prompt: str, user_prompt: str) -> str:
system=system_prompt,
messages=[{"role": "user", "content": user_prompt}],
)

return message.content[0].text # type: ignore[union-attr]

async def call_instructor(
134 changes: 134 additions & 0 deletions src/cloai/llm/ollama.py
@@ -0,0 +1,134 @@
"""Ollama LLM client implementation."""

import json
from typing import Any, TypeVar, get_args, get_origin

import ollama
import pydantic

from cloai.llm.utils import LlmBaseClass

T = TypeVar("T")


class OllamaLlm(LlmBaseClass):
"""Client for Ollama API."""

def __init__(
self,
model: str,
base_url: str,
) -> None:
"""Initialize Ollama client.

Args:
model: The model to run, must already be installed on the host via ollama.
base_url: The URL of the Ollama API.
"""
self.model = model
self.client = ollama.AsyncClient(host=base_url)

async def run(self, system_prompt: str, user_prompt: str) -> str:
"""Call Ollama model."""
response = await self.client.chat(
model=self.model,
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
],
)
return response["message"]["content"]

async def call_instructor(
self,
response_model: type[T],
system_prompt: str,
user_prompt: str,
max_tokens: int = 4096,
) -> T:
"""Run a type-safe large language model query.

This function uses Pydantic to convert any arbitrary class to JSON
schema. This is unlikely to be fool-proof, but we can deal with issues
as they arise.

Args:
response_model: The Pydantic response model.
system_prompt: The system prompt.
user_prompt: The user prompt.
max_tokens: The maximum number of tokens to allow.

Returns:
The response as the requested object.
"""
default_max_tokens = 4096
if max_tokens != default_max_tokens:
msg = "max_tokens has not yet been implemented in Ollama."
raise NotImplementedError(msg)

# Use Pydantic for converting an arbitrary class to JSON schema.
schema = pydantic.create_model(
response_model.__name__,
field=(response_model, ...),
).model_json_schema()

response = await self.client.chat(
model=self.model,
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
},
],
format=schema,
)

data = json.loads(response.message.content)["field"] # type: ignore[arg-type]
return _model_and_data_to_object(response_model, data)


def _model_and_data_to_object(cls: type[T], data: Any) -> Any: # noqa: ANN401
"""Convert JSON data to the specified type.

Args:
cls: The target class type.
data: The JSON data to convert.

Returns:
An instance of the target class.
"""
# Pydantic models
try:
return cls.model_validate(data) # type: ignore[call-arg, attr-defined]
except AttributeError:
# Not a Pydantic model.
pass

# Lists/tuples
if cls in (list, tuple):
return cls(data) # type: ignore[call-arg]

if get_origin(cls) in (list, tuple):
item_types = get_args(cls)
if len(item_types) > 1:
msg = "Only one item type may be present in a list/tuple type."
raise NotImplementedError(msg)
return cls(_model_and_data_to_object(item_types[0], item) for item in data) # type: ignore[call-arg]

# Basic Python types
if cls in (int, float, str, bool):
return cls(data) # type: ignore[call-arg]

# If we get here, we don't know how to handle this type
msg = f"Unable to convert data to type {cls}"
raise ValueError(msg)
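A note on the `call_instructor` approach above: because Ollama's structured output takes a JSON schema via `format`, the client wraps whatever `response_model` it is given (including plain types such as `int` or `list[str]`) in a single-field Pydantic model, then unwraps `"field"` from the reply with `_model_and_data_to_object`. A standalone sketch of that wrapping step, using only Pydantic and no Ollama server:

```python
# Illustration of the schema-wrapping trick used in call_instructor above.
# Only pydantic is required; nothing here talks to an Ollama server.
import pydantic

response_model = list[int]  # a response type that is not a BaseModel

# Wrapping the type in a single "field" gives us a JSON schema even though
# list[int] itself has no .model_json_schema().
wrapper = pydantic.create_model(response_model.__name__, field=(response_model, ...))
schema = wrapper.model_json_schema()
print(schema["properties"]["field"])  # e.g. {'type': 'array', 'items': {'type': 'integer'}, ...}

# The model is then expected to reply with JSON shaped like {"field": [1, 2, 3]},
# which _model_and_data_to_object converts back into a list[int].
```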
50 changes: 35 additions & 15 deletions tests/integration/test_llm.py
@@ -5,7 +5,9 @@
import pydantic
import pytest

from cloai.llm import bedrock, llm, openai
from cloai.llm import bedrock, llm, ollama, openai

LLM_MODELS = ["openai", "bedrock", "ollama"]


@pytest.fixture
@@ -30,24 +32,34 @@ def bedrock_anthropic_model() -> llm.LargeLanguageModel:
return llm.LargeLanguageModel(client=client)


@pytest.fixture
def ollama_model() -> llm.LargeLanguageModel:
"""Creates the Ollama client. Requires ollama installed with llama3.2:1b."""
client = ollama.OllamaLlm("llama3.2:1b", "http://localhost:11434")
return llm.LargeLanguageModel(client=client)


@pytest.fixture
def model(
request: pytest.FixtureRequest,
openai_model: llm.LargeLanguageModel,
bedrock_anthropic_model: llm.LargeLanguageModel,
ollama_model: llm.LargeLanguageModel,
) -> llm.LargeLanguageModel:
"""Fetches the LLM."""
name = request.param
if name == "openai":
return openai_model
if name == "bedrock":
return bedrock_anthropic_model
if name == "ollama":
return ollama_model

msg = "Wrong model name."
raise ValueError(msg)


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_run(model: llm.LargeLanguageModel) -> None:
"""Test the run command."""
@@ -60,27 +72,33 @@ async def test_run(model: llm.LargeLanguageModel) -> None:
assert len(actual) > 0


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.asyncio
async def test_call_instructor(model: llm.LargeLanguageModel) -> None:
"""Test the call_instructor command."""
class Response(pydantic.BaseModel):
"""Testing response model for instructor."""

class Response(pydantic.BaseModel):
grade: int = pydantic.Field(..., lt=10, gt=0)
grade: int = pydantic.Field(..., lt=10, gt=0)


@pytest.mark.parametrize("response", [Response, int])
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_call_instructor(
model: llm.LargeLanguageModel,
response: type[Response] | type[int],
) -> None:
"""Test the call_instructor command."""
system_prompt = "Return the user message."
user_prompt = "{'grade': 3}"

actual = await model.call_instructor(
response_model=Response,
response_model=response,
system_prompt=system_prompt,
user_prompt=user_prompt,
)

assert isinstance(actual, Response)
assert isinstance(actual, response)


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_density(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_density command."""
@@ -108,7 +126,7 @@ async def test_chain_of_density(model: llm.LargeLanguageModel) -> None:
assert len(actual) > 0


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_verification_str(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_verification command."""
@@ -122,13 +140,15 @@ async def test_chain_of_verification_str(model: llm.LargeLanguageModel) -> None:
)

assert isinstance(actual, str)
assert "horse" in actual.lower()


@pytest.mark.parametrize("model", ["openai", "bedrock"], indirect=True)
@pytest.mark.parametrize("model", LLM_MODELS, indirect=True)
@pytest.mark.asyncio
async def test_chain_of_verification_model(model: llm.LargeLanguageModel) -> None:
"""Test the chain_of_verification command."""
"""Test the chain_of_verification command.

This test may be unstable with Ollama depending on the model used.
"""
text = "Lea is 9 years old. She likes riding horses."

class Response(pydantic.BaseModel):
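To run only the Ollama-backed integration tests locally, something like the following should work (hypothetical invocation: it assumes `ollama serve` is up, `llama3.2:1b` has been pulled, and pytest's `-k` filter matches the `ollama` parametrization id):

```python
# Hypothetical local run of just the ollama-parametrized integration tests.
import pytest

pytest.main(["tests/integration/test_llm.py", "-k", "ollama", "-q"])
```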
@@ -4,22 +4,25 @@
but they are the best we can do without connecting to remote servers on every test.
"""

import json
import types
from unittest import mock

import pydantic
import pytest
import pytest_mock

from cloai.llm import bedrock, openai, utils
from cloai.llm import bedrock, ollama, openai, utils

TEST_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
TEST_SYSTEM_PROMPT = "You are a helpful assistant."
TEST_USER_PROMPT = "What is 2+2?"
TEST_RUN_RESPONSE = "Hello world!"

LLM_TYPE = bedrock.AnthropicBedrockLlm | openai.OpenAiLlm | openai.AzureLlm
llms = ("azure", "anthropic_bedrock", "openai")
LLM_TYPE = (
bedrock.AnthropicBedrockLlm | openai.OpenAiLlm | openai.AzureLlm | ollama.OllamaLlm
)
llms = ("azure", "anthropic_bedrock", "openai", "ollama")


class _TestResponse(pydantic.BaseModel):
@@ -134,6 +137,17 @@ def openai_llm(
)


@pytest.fixture
def ollama_llm(mocker: pytest_mock.MockerFixture) -> ollama.OllamaLlm:
"""Create the mocked anthropic bedrock llm."""
response = {"message": {"content": TEST_RUN_RESPONSE}}
mocker.patch("ollama.AsyncClient.chat", return_value=response)
return ollama.OllamaLlm(
model=TEST_MODEL,
base_url="somethinglocal",
)


@pytest.fixture
def azure_llm(
mock_azure_client: mock.MagicMock,
@@ -154,6 +168,7 @@ def llm(
openai_llm: openai.OpenAiLlm,
azure_llm: openai.AzureLlm,
anthropic_bedrock_llm: bedrock.AnthropicBedrockLlm,
ollama_llm: ollama.OllamaLlm,
) -> utils.LlmBaseClass:
"""Create the mocked llm."""
name = request.param
@@ -163,6 +178,8 @@
return anthropic_bedrock_llm
if name == "azure":
return azure_llm
if name == "ollama":
return ollama_llm
raise NotImplementedError


@@ -201,7 +218,17 @@ async def test_call_instructor_method(
) -> None:
"""Test the call_instructor method."""
expected_response = _TestResponse(answer="4")
llm._instructor.chat.completions.create.return_value = expected_response # type: ignore[call-overload, attr-defined]
if isinstance(llm, ollama.OllamaLlm):
# Ollama doesn't use instructor and therefore requires custom handling.
class Content(pydantic.BaseModel):
content: str = json.dumps({"field": _TestResponse(answer="4").model_dump()})

class Response(pydantic.BaseModel):
message: Content = Content()

llm.client.chat.return_value = Response() # type: ignore[attr-defined]
else:
llm._instructor.chat.completions.create.return_value = expected_response # type: ignore[call-overload, attr-defined]

result = await llm.call_instructor(
_TestResponse,
2 changes: 1 addition & 1 deletion tests/unit/llm/test_llm_unit.py
@@ -32,7 +32,7 @@ def test_recursive_pydantic_model_dump_primitive() -> None:
assert actual == expected


def test_recursive_pydantic_model_dump_recusive() -> None:
def test_recursive_pydantic_model_dump_recursive() -> None:
"""Test dumping a model containing a model."""
model = ModelRecursive()
expected = model.model_dump()