Skip to content

Commit

Permalink
test: test fix for constructor methods
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 15, 2025
1 parent 87a61b0 commit 4b9c1f0
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 152 deletions.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

__all__ = 'IsNow', 'TestEnv', 'ClientWithHandler', 'try_import'


pydantic_ai.models.ALLOW_MODEL_REQUESTS = False

if TYPE_CHECKING:
Expand Down
16 changes: 3 additions & 13 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -103,15 +102,9 @@ async def test_sync_request_text_response(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down Expand Up @@ -245,9 +238,6 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)
20 changes: 4 additions & 16 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='Hello world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
Expand All @@ -458,15 +455,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='Hello world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='Hello world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down Expand Up @@ -589,10 +580,7 @@ async def get_location(loc_name: str) -> str:
),
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)
assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9))
Expand Down
16 changes: 3 additions & 13 deletions tests/models/test_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -140,15 +139,9 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)

Expand Down Expand Up @@ -314,10 +307,7 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
]
)

Expand Down
41 changes: 9 additions & 32 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -241,15 +240,9 @@ async def test_multiple_completions(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='hello again')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='hello again', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)

Expand Down Expand Up @@ -291,20 +284,11 @@ async def test_three_completions(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='hello again')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='hello again', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='final message')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='final message', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)

Expand Down Expand Up @@ -1161,9 +1145,8 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
ModelResponse.from_text(
content='final response', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
),
]
)
Expand Down Expand Up @@ -1521,10 +1504,7 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down Expand Up @@ -1645,10 +1625,7 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down
39 changes: 17 additions & 22 deletions tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -129,11 +127,7 @@ def test_weather(set_event_loop: None):
[
ModelRequest(parts=[UserPromptPart(content='London', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
tool_name='get_location', args=ArgsJson(args_json='{"location_description": "London"}')
)
],
parts=[ToolCallPart.from_raw_args('get_location', '{"location_description": "London"}')],
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
Expand All @@ -144,14 +138,19 @@ def test_weather(set_event_loop: None):
]
),
ModelResponse(
parts=[ToolCallPart(tool_name='get_weather', args=ArgsJson(args_json='{"lat": 51, "lng": 0}'))],
parts=[
ToolCallPart.from_raw_args(
'get_weather',
'{"lat": 51, "lng": 0}',
)
],
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[ToolReturnPart(tool_name='get_weather', content='Raining', timestamp=IsNow(tz=timezone.utc))]
),
ModelResponse(
parts=[TextPart(content='Raining in London')],
ModelResponse.from_text(
content='Raining in London',
timestamp=IsNow(tz=timezone.utc),
),
]
Expand Down Expand Up @@ -314,11 +313,11 @@ def test_call_all(set_event_loop: None):
),
ModelResponse(
parts=[
ToolCallPart(tool_name='foo', args=ArgsDict(args_dict={'x': 0})),
ToolCallPart(tool_name='bar', args=ArgsDict(args_dict={'x': 0})),
ToolCallPart(tool_name='baz', args=ArgsDict(args_dict={'x': 0})),
ToolCallPart(tool_name='qux', args=ArgsDict(args_dict={'x': 0})),
ToolCallPart(tool_name='quz', args=ArgsDict(args_dict={'x': 'a'})),
ToolCallPart.from_raw_args('foo', {'x': 0}),
ToolCallPart.from_raw_args('bar', {'x': 0}),
ToolCallPart.from_raw_args('baz', {'x': 0}),
ToolCallPart.from_raw_args('qux', {'x': 0}),
ToolCallPart.from_raw_args('quz', {'x': 'a'}),
],
timestamp=IsNow(tz=timezone.utc),
),
Expand All @@ -331,9 +330,8 @@ def test_call_all(set_event_loop: None):
ToolReturnPart(tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc)),
]
),
ModelResponse(
parts=[TextPart(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')],
timestamp=IsNow(tz=timezone.utc),
ModelResponse.from_text(
content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow(tz=timezone.utc)
),
]
)
Expand Down Expand Up @@ -398,10 +396,7 @@ async def test_stream_text():
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='hello world')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
assert result.usage() == snapshot(Usage(requests=1))
Expand Down
14 changes: 3 additions & 11 deletions tests/models/test_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
ArgsDict,
ModelRequest,
ModelResponse,
RetryPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -97,23 +95,17 @@ async def my_ret(x: int) -> str:
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='my_ret', args=ArgsDict(args_dict={'x': 0}))],
parts=[ToolCallPart.from_raw_args('my_ret', {'x': 0})],
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[
RetryPromptPart(content='First call failed', tool_name='my_ret', timestamp=IsNow(tz=timezone.utc))
]
),
ModelResponse(
parts=[ToolCallPart(tool_name='my_ret', args=ArgsDict(args_dict={'x': 0}))],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse(parts=[ToolCallPart.from_raw_args('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)),
ModelRequest(parts=[ToolReturnPart(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='{"my_ret":"1"}')],
timestamp=IsNow(tz=timezone.utc),
),
ModelResponse.from_text(content='{"my_ret":"1"}', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down
11 changes: 2 additions & 9 deletions tests/models/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
TextPart,
UserPromptPart,
)
from pydantic_ai.result import Usage
Expand Down Expand Up @@ -56,14 +55,8 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)
16 changes: 3 additions & 13 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -149,15 +148,9 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='world')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)

Expand Down Expand Up @@ -326,10 +319,7 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[TextPart(content='final response')],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
]
)
assert result.usage() == snapshot(
Expand Down
Loading

0 comments on commit 4b9c1f0

Please sign in to comment.