Skip to content

Commit

Permalink
Ensure TestModel handles result retries correctly (#572)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <[email protected]>
  • Loading branch information
jlowin and samuelcolvin authored Dec 30, 2024
1 parent d94931e commit 7d9e487
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
25 changes: 18 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
RetryPromptPart,
TextPart,
ToolCallPart,
Expand Down Expand Up @@ -177,13 +178,23 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
# check if there are any retry prompts, if so retry them
new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
if new_retry_names:
return ModelResponse(
parts=[
ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
for name, args in self.tool_calls
if name in new_retry_names
]
)
# Handle retries for both function tools and result tools
# Check function tools first
retry_parts: list[ModelResponsePart] = [
ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
for name, args in self.tool_calls
if name in new_retry_names
]
# Check result tools
if self.result_tools:
retry_parts.extend(
[
ToolCallPart.from_raw_args(tool.name, self.gen_tool_args(tool))
for tool in self.result_tools
if tool.name in new_retry_names
]
)
return ModelResponse(parts=retry_parts)

if response_text := self.result.left:
if response_text.value is None:
Expand Down
24 changes: 23 additions & 1 deletion tests/models/test_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from inline_snapshot import snapshot
from pydantic import BaseModel, Field

from pydantic_ai import Agent, ModelRetry
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -109,6 +110,27 @@ async def my_ret(x: int) -> str:
)


def test_result_tool_retry_error_handled(set_event_loop: None):
class ResultModel(BaseModel):
x: int
y: str

agent = Agent('test', result_type=ResultModel, retries=2)

call_count = 0

@agent.result_validator
def validate_result(ctx: RunContext[None], result: ResultModel) -> ResultModel:
nonlocal call_count
call_count += 1
raise ModelRetry('Fail')

with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'):
agent.run_sync('Hello', model=TestModel())

assert call_count == 3


def test_json_schema_test_data():
class NestedModel(BaseModel):
foo: str
Expand Down

0 comments on commit 7d9e487

Please sign in to comment.