diff --git a/tests/conftest.py b/tests/conftest.py index 2347f6088..4a219ce00 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ __all__ = 'IsNow', 'TestEnv', 'ClientWithHandler', 'try_import' + pydantic_ai.models.ALLOW_MODEL_REQUESTS = False if TYPE_CHECKING: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index d065f88ef..66a16605a 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -16,7 +16,6 @@ ModelResponse, RetryPromptPart, SystemPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -103,15 +102,9 @@ async def test_sync_request_text_response(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)), ] ) @@ -245,9 +238,6 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 56ad61018..ef7c5e6a6 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -445,10 +445,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='Hello world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)), ] ) assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) @@ -458,15 +455,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='Hello world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='Hello world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)), ] ) @@ -589,10 +580,7 @@ async def get_location(loc_name: str) -> str: ), ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9)) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 4ce2c8510..93a7cf637 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -18,7 +18,6 @@ ModelResponse, RetryPromptPart, SystemPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -140,15 +139,9 @@ async def test_request_simple_success(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -314,10 +307,7 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)), ] ) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index a1f6927ab..8dc793634 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -22,7 +22,6 @@ ModelResponse, RetryPromptPart, SystemPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -241,15 +240,9 @@ async def test_multiple_completions(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='hello again')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='hello again', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -291,20 +284,11 @@ async def test_three_completions(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='hello again')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='hello again', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='final message')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='final message', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -1161,9 +1145,8 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ModelResponse.from_text( + content='final response', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) ), ] ) @@ -1521,10 +1504,7 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) @@ -1645,10 +1625,7 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 70173590e..dec5c2fa7 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -12,8 +12,6 @@ from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.messages import ( - ArgsDict, - ArgsJson, ModelMessage, ModelRequest, ModelResponse, @@ -129,11 +127,7 @@ def test_weather(set_event_loop: None): [ ModelRequest(parts=[UserPromptPart(content='London', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( - parts=[ - ToolCallPart( - tool_name='get_location', args=ArgsJson(args_json='{"location_description": "London"}') - ) - ], + parts=[ToolCallPart.from_raw_args('get_location', '{"location_description": "London"}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -144,14 +138,19 @@ def test_weather(set_event_loop: None): ] ), ModelResponse( - parts=[ToolCallPart(tool_name='get_weather', args=ArgsJson(args_json='{"lat": 51, "lng": 0}'))], + parts=[ + ToolCallPart.from_raw_args( + 'get_weather', + '{"lat": 51, "lng": 0}', + ) + ], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( parts=[ToolReturnPart(tool_name='get_weather', content='Raining', timestamp=IsNow(tz=timezone.utc))] ), - ModelResponse( - parts=[TextPart(content='Raining in London')], + ModelResponse.from_text( + content='Raining in London', timestamp=IsNow(tz=timezone.utc), ), ] @@ -314,11 +313,11 @@ def test_call_all(set_event_loop: None): ), ModelResponse( parts=[ - ToolCallPart(tool_name='foo', args=ArgsDict(args_dict={'x': 0})), - ToolCallPart(tool_name='bar', args=ArgsDict(args_dict={'x': 0})), - ToolCallPart(tool_name='baz', args=ArgsDict(args_dict={'x': 0})), - ToolCallPart(tool_name='qux', args=ArgsDict(args_dict={'x': 0})), - ToolCallPart(tool_name='quz', args=ArgsDict(args_dict={'x': 'a'})), + ToolCallPart.from_raw_args('foo', {'x': 0}), + ToolCallPart.from_raw_args('bar', {'x': 0}), + ToolCallPart.from_raw_args('baz', {'x': 0}), + ToolCallPart.from_raw_args('qux', {'x': 0}), + ToolCallPart.from_raw_args('quz', {'x': 'a'}), ], timestamp=IsNow(tz=timezone.utc), ), @@ -331,9 +330,8 @@ def test_call_all(set_event_loop: None): ToolReturnPart(tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc)), ] ), - ModelResponse( - parts=[TextPart(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')], - timestamp=IsNow(tz=timezone.utc), + ModelResponse.from_text( + content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow(tz=timezone.utc) ), ] ) @@ -398,10 +396,7 @@ async def test_stream_text(): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='hello world')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)), ] ) assert result.usage() == snapshot(Usage(requests=1)) diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 8b5c7893c..ac8edd3bc 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -13,11 +13,9 @@ from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( - ArgsDict, ModelRequest, ModelResponse, RetryPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -97,7 +95,7 @@ async def my_ret(x: int) -> str: [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( - parts=[ToolCallPart(tool_name='my_ret', args=ArgsDict(args_dict={'x': 0}))], + parts=[ToolCallPart.from_raw_args('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -105,15 +103,9 @@ async def my_ret(x: int) -> str: RetryPromptPart(content='First call failed', tool_name='my_ret', timestamp=IsNow(tz=timezone.utc)) ] ), - ModelResponse( - parts=[ToolCallPart(tool_name='my_ret', args=ArgsDict(args_dict={'x': 0}))], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse(parts=[ToolCallPart.from_raw_args('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)), ModelRequest(parts=[ToolReturnPart(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='{"my_ret":"1"}')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='{"my_ret":"1"}', timestamp=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index 635905aa7..c608c656f 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -9,7 +9,6 @@ from pydantic_ai.messages import ( ModelRequest, ModelResponse, - TextPart, UserPromptPart, ) from pydantic_ai.result import Usage @@ -56,14 +55,8 @@ async def test_request_simple_success(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 15434970b..722f51f78 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -18,7 +18,6 @@ ModelResponse, RetryPromptPart, SystemPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -149,15 +148,9 @@ async def test_request_simple_success(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='world')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -326,10 +319,7 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)), ] ) assert result.usage() == snapshot( diff --git a/tests/test_agent.py b/tests/test_agent.py index 0c7c22b9f..ed03e1b78 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -88,7 +88,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"a": "wrong", "b": "foo"}'))], + parts=[ToolCallPart.from_raw_args('final_result', '{"a": "wrong", "b": "foo"}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -108,7 +108,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse ] ), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"a": 42, "b": "foo"}'))], + parts=[ToolCallPart.from_raw_args('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -203,7 +203,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo: [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"a": 41, "b": "foo"}'))], + parts=[ToolCallPart.from_raw_args('final_result', '{"a": 41, "b": "foo"}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -214,7 +214,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo: ] ), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"a": 42, "b": "foo"}'))], + parts=[ToolCallPart.from_raw_args('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -250,10 +250,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='hello')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='hello', timestamp=IsNow(tz=timezone.utc)), ModelRequest( parts=[ RetryPromptPart( @@ -776,10 +773,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ) ] ), - ModelResponse( - parts=[TextPart(content='success')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='success', timestamp=IsNow(tz=timezone.utc)), ] ) @@ -1225,10 +1219,7 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[TextPart(content='final response')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 678afee11..0e5dbc3f3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -16,7 +16,6 @@ ModelRequest, ModelResponse, RetryPromptPart, - TextPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -86,10 +85,7 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] ), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], - timestamp=IsNow(tz=timezone.utc), - ), + ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ] ) diff --git a/uv.lock b/uv.lock index db04c5278..fbc5cabc7 100644 --- a/uv.lock +++ b/uv.lock @@ -923,8 +923,8 @@ wheels = [ [[package]] name = "inline-snapshot" -version = "0.19.0" -source = { git = "https://github.com/15r10nk/inline-snapshot?rev=pydantic_ai_fixes#793ad2f70be7b3508781f702a6b4afcd2d17dfdc" } +version = "0.19.1" +source = { git = "https://github.com/15r10nk/inline-snapshot?rev=fix_pydantic_ai_errors#22b351f7a56d65c00026a10f427472397e6f3d28" } dependencies = [ { name = "asttokens" }, { name = "executing" },