From 7d9e4877e0a8869676470daf1701e5bb7e8c918d Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 30 Dec 2024 13:53:47 -0500 Subject: [PATCH] Ensure `TestModel` handles result retries correctly (#572) Co-authored-by: Samuel Colvin --- pydantic_ai_slim/pydantic_ai/models/test.py | 25 +++++++++++++++------ tests/models/test_model_test.py | 24 +++++++++++++++++++- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index dddb2e7b..9044077a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -16,6 +16,7 @@ ModelMessage, ModelRequest, ModelResponse, + ModelResponsePart, RetryPromptPart, TextPart, ToolCallPart, @@ -177,13 +178,23 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | # check if there are any retry prompts, if so retry them new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)} if new_retry_names: - return ModelResponse( - parts=[ - ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) - for name, args in self.tool_calls - if name in new_retry_names - ] - ) + # Handle retries for both function tools and result tools + # Check function tools first + retry_parts: list[ModelResponsePart] = [ + ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) + for name, args in self.tool_calls + if name in new_retry_names + ] + # Check result tools + if self.result_tools: + retry_parts.extend( + [ + ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool)) + for tool in self.result_tools + if tool.name in new_retry_names + ] + ) + return ModelResponse(parts=retry_parts) if response_text := self.result.left: if response_text.value is None: diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 64af46f7..ac8edd3b 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -10,7 +10,8 @@ from inline_snapshot import snapshot from pydantic import BaseModel, Field -from pydantic_ai import Agent, ModelRetry +from pydantic_ai import Agent, ModelRetry, RunContext +from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( ModelRequest, ModelResponse, @@ -109,6 +110,27 @@ async def my_ret(x: int) -> str: ) +def test_result_tool_retry_error_handled(set_event_loop: None): + class ResultModel(BaseModel): + x: int + y: str + + agent = Agent('test', result_type=ResultModel, retries=2) + + call_count = 0 + + @agent.result_validator + def validate_result(ctx: RunContext[None], result: ResultModel) -> ResultModel: + nonlocal call_count + call_count += 1 + raise ModelRetry('Fail') + + with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + agent.run_sync('Hello', model=TestModel()) + + assert call_count == 3 + + def test_json_schema_test_data(): class NestedModel(BaseModel): foo: str