Skip to content

Commit

Permalink
Prefix all models with provider for consistency (#593)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <[email protected]>
  • Loading branch information
sydney-runkle and samuelcolvin authored Jan 7, 2025
1 parent 34ec8a6 commit 7920cb9
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 47 deletions.
6 changes: 5 additions & 1 deletion docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,14 @@ You can then use [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] by name:
```python {title="gemini_model_by_name.py"}
from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')
agent = Agent('google-gla:gemini-1.5-flash')
...
```

!!! note
    The `google-gla` provider prefix stands for the [Google **G**enerative **L**anguage **A**PI](https://ai.google.dev/api/all-methods), which is used by `GeminiModel`s.
    The `google-vertex` prefix is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) for `VertexAIModel`s.

Or initialise the model directly with just the model name:

```python {title="gemini_model_init.py"}
Expand Down
2 changes: 1 addition & 1 deletion examples/pydantic_ai_examples/sql_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class InvalidRequest(BaseModel):

Response: TypeAlias = Union[Success, InvalidRequest]
agent: Agent[Deps, Response] = Agent(
'gemini-1.5-flash',
'google-gla:gemini-1.5-flash',
# Type ignore while we wait for PEP 747; nonetheless, unions will work fine everywhere else
result_type=Response, # type: ignore
deps_type=Deps,
Expand Down
2 changes: 1 addition & 1 deletion examples/pydantic_ai_examples/stream_markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# models to try, and the appropriate env var
models: list[tuple[KnownModelName, str]] = [
('gemini-1.5-flash', 'GEMINI_API_KEY'),
('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'),
('openai:gpt-4o-mini', 'OPENAI_API_KEY'),
('groq:llama-3.1-70b-versatile', 'GROQ_API_KEY'),
]
Expand Down
38 changes: 28 additions & 10 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@
'groq:mixtral-8x7b-32768',
'groq:gemma2-9b-it',
'groq:gemma-7b-it',
'gemini-1.5-flash',
'gemini-1.5-pro',
'gemini-2.0-flash-exp',
'vertexai:gemini-1.5-flash',
'vertexai:gemini-1.5-pro',
# since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
# don't start with "mistral", we add the "mistral:" prefix to all to be explicit
'google-gla:gemini-1.5-flash',
'google-gla:gemini-1.5-pro',
'google-gla:gemini-2.0-flash-exp',
'google-vertex:gemini-1.5-flash',
'google-vertex:gemini-1.5-pro',
'google-vertex:gemini-2.0-flash-exp',
'mistral:mistral-small-latest',
'mistral:mistral-large-latest',
'mistral:codestral-latest',
Expand All @@ -76,9 +75,9 @@
'ollama:qwen2',
'ollama:qwen2.5',
'ollama:starcoder2',
'claude-3-5-haiku-latest',
'claude-3-5-sonnet-latest',
'claude-3-opus-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic:claude-3-5-sonnet-latest',
'anthropic:claude-3-opus-latest',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
Expand Down Expand Up @@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .openai import OpenAIModel

return OpenAIModel(model[7:])
elif model.startswith(('gpt', 'o1')):
from .openai import OpenAIModel

return OpenAIModel(model)
elif model.startswith('google-gla'):
from .gemini import GeminiModel

return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
# backwards compatibility with old model names (e.g. gemini-1.5-flash -> google-gla:gemini-1.5-flash)
elif model.startswith('gemini'):
from .gemini import GeminiModel

Expand All @@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .groq import GroqModel

return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
elif model.startswith('google-vertex'):
from .vertexai import VertexAIModel

return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
# backwards compatibility with old model names (e.g. vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
elif model.startswith('vertexai:'):
from .vertexai import VertexAIModel

Expand All @@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .ollama import OllamaModel

return OllamaModel(model[7:])
elif model.startswith('anthropic'):
from .anthropic import AnthropicModel

return AnthropicModel(model[10:])
# backwards compatibility with old model names (e.g. claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
elif model.startswith('claude'):
from .anthropic import AnthropicModel

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def agent_model(
)

def name(self) -> str:
return self.model_name
return f'anthropic:{self.model_name}'

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ToolParam:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def agent_model(
)

def name(self) -> str:
return self.model_name
return f'google-gla:{self.model_name}'


class AuthProtocol(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def ainit(self) -> tuple[str, BearerTokenAuth]:
return url, auth

def name(self) -> str:
return f'vertexai:{self.model_name}'
return f'google-vertex:{self.model_name}'


# pyright: reportUnknownMemberType=false
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
assert m.client.api_key == 'foobar'
assert m.name() == 'claude-3-5-haiku-latest'
assert m.name() == 'anthropic:claude-3-5-haiku-latest'


@dataclass
Expand Down
93 changes: 67 additions & 26 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,84 @@
from importlib import import_module

import pytest

from pydantic_ai import UserError
from pydantic_ai.models import infer_model
from pydantic_ai.models.gemini import GeminiModel

from ..conftest import TestEnv, try_import
from ..conftest import TestEnv

# Parametrized cases for test_infer_model. Each tuple is
# (mock_api_key, model_name, expected_model_name, module_name, model_class_name):
#   mock_api_key        - environment variable set to a dummy value before inference
#   model_name          - string passed to infer_model()
#   expected_model_name - value the inferred model's name() must return
#   module_name         - submodule of pydantic_ai.models that the expected class is imported from
#   model_class_name    - class the inferred model must be an instance of
# Un-prefixed and legacy names (e.g. 'gpt-3.5-turbo', 'gemini-1.5-flash', 'vertexai:...',
# 'claude-3-5-haiku-latest') are listed next to their prefixed forms to check that both
# resolve to the same provider-prefixed name.
TEST_CASES = [
('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
(
'GEMINI_API_KEY',
'google-vertex:gemini-1.5-flash',
'google-vertex:gemini-1.5-flash',
'vertexai',
'VertexAIModel',
),
(
'GEMINI_API_KEY',
'vertexai:gemini-1.5-flash',
'google-vertex:gemini-1.5-flash',
'vertexai',
'VertexAIModel',
),
(
'ANTHROPIC_API_KEY',
'anthropic:claude-3-5-haiku-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic',
'AnthropicModel',
),
(
'ANTHROPIC_API_KEY',
'claude-3-5-haiku-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic',
'AnthropicModel',
),
(
'GROQ_API_KEY',
'groq:llama-3.3-70b-versatile',
'groq:llama-3.3-70b-versatile',
'groq',
'GroqModel',
),
('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', 'ollama', 'OllamaModel'),
(
'MISTRAL_API_KEY',
'mistral:mistral-small-latest',
'mistral:mistral-small-latest',
'mistral',
'MistralModel',
),
]

with try_import() as openai_imports_successful:
from pydantic_ai.models.openai import OpenAIModel

with try_import() as vertexai_imports_successful:
from pydantic_ai.models.vertexai import VertexAIModel
@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
def test_infer_model(
env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
):
try:
model_module = import_module(f'pydantic_ai.models.{module_name}')
expected_model = getattr(model_module, model_class_name)
except ImportError:
pytest.skip(f'{model_name} dependencies not installed')

env.set(mock_api_key, 'via-env-var')

@pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed')
def test_infer_str_openai(env: TestEnv):
env.set('OPENAI_API_KEY', 'via-env-var')
m = infer_model('openai:gpt-3.5-turbo')
assert isinstance(m, OpenAIModel)
assert m.name() == 'openai:gpt-3.5-turbo'
m = infer_model(model_name) # pyright: ignore[reportArgumentType]
assert isinstance(m, expected_model)
assert m.name() == expected_model_name

m2 = infer_model(m)
assert m2 is m


def test_infer_str_gemini(env: TestEnv):
env.set('GEMINI_API_KEY', 'via-env-var')
m = infer_model('gemini-1.5-flash')
assert isinstance(m, GeminiModel)
assert m.name() == 'gemini-1.5-flash'


@pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed')
def test_infer_vertexai(env: TestEnv):
m = infer_model('vertexai:gemini-1.5-flash')
assert isinstance(m, VertexAIModel)
assert m.name() == 'vertexai:gemini-1.5-flash'


def test_infer_str_unknown():
with pytest.raises(UserError, match='Unknown model: foobar'):
infer_model('foobar') # pyright: ignore[reportArgumentType]
4 changes: 2 additions & 2 deletions tests/models/test_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
'publishers/google/models/gemini-1.5-flash:'
)
assert model.auth is not None
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')


class NoOpCredentials:
Expand All @@ -67,7 +67,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
'publishers/google/models/gemini-1.5-flash:'
)
assert model.auth is not None
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')

await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
assert model.url is not None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,15 +780,15 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:

def test_model_requests_blocked(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
agent.run_sync('Hello')


def test_override_model(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

with agent.override(model='test'):
result = agent.run_sync('Hello')
Expand Down

0 comments on commit 7920cb9

Please sign in to comment.