diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py
index 7e9a2eac..4e3901d2 100644
--- a/pydantic_ai/agent.py
+++ b/pydantic_ai/agent.py
@@ -248,7 +248,12 @@ async def _handle_model_response(
             if self._allow_text_result:
                 return _utils.Either(left=cast(result.ResultData, model_response.content))
             else:
-                return _utils.Either(right=[_messages.PlainResponseForbidden()])
+                self._incr_result_retry()
+                assert self._result_tool is not None
+                response = _messages.UserPrompt(
+                    content='Plain text responses are not permitted, please call one of the functions instead.',
+                )
+                return _utils.Either(right=[response])
         elif model_response.role == 'llm-tool-calls':
             if self._result_tool is not None:
                 # if there's a result schema, and any of the calls match that name, return the result
diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py
index 2401c777..5eae6b39 100644
--- a/pydantic_ai/messages.py
+++ b/pydantic_ai/messages.py
@@ -50,17 +50,6 @@ def llm_response(self) -> str:
         return f'{description}\n\nFix the errors and try again.'
 
 
-@dataclass
-class PlainResponseForbidden:
-    # TODO remove and replace with ToolRetry
-    timestamp: datetime = field(default_factory=datetime.now)
-    role: Literal['plain-response-forbidden'] = 'plain-response-forbidden'
-
-    @staticmethod
-    def llm_response() -> str:
-        return 'Plain text responses are not allowed, please call one of the functions instead.'
-
-
 @dataclass
 class LLMResponse:
     content: str
@@ -105,6 +94,6 @@ class LLMToolCalls:
 LLMMessage = Union[LLMResponse, LLMToolCalls]
 
-Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, PlainResponseForbidden, LLMMessage]
+Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, LLMMessage]
 
 MessagesTypeAdapter = pydantic.TypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py
index 3c591be8..ebb09432 100644
--- a/pydantic_ai/models/gemini.py
+++ b/pydantic_ai/models/gemini.py
@@ -176,8 +176,6 @@ def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiConte
     elif m.role == 'tool-retry':
         # ToolRetry ->
         return _utils.Either(right=_GeminiContent.function_retry(m))
-    elif m.role == 'plain-response-forbidden':
-        return _utils.Either(right=_GeminiContent.user_text(m.llm_response()))
     elif m.role == 'llm-response':
         # LLMResponse ->
         return _utils.Either(right=_GeminiContent.model_text(m.content))
diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py
index 2700984a..9f27b476 100644
--- a/pydantic_ai/models/openai.py
+++ b/pydantic_ai/models/openai.py
@@ -144,12 +144,6 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam:
             role='assistant',
             tool_calls=[_map_tool_call(t) for t in message.calls],
         )
-    elif message.role == 'plain-response-forbidden':
-        # PlainResponseForbidden ->
-        return chat.ChatCompletionUserMessageParam(
-            role='user',
-            content=message.llm_response(),
-        )
     else:
         assert_never(message)
 
diff --git a/tests/test_result_validation.py b/tests/test_agent.py
similarity index 74%
rename from tests/test_result_validation.py
rename to tests/test_agent.py
index d0010d30..6ce31ec9 100644
--- a/tests/test_result_validation.py
+++ b/tests/test_agent.py
@@ -2,7 +2,16 @@
 from pydantic import BaseModel
 
 from pydantic_ai import Agent, ModelRetry
-from pydantic_ai.messages import LLMMessage, LLMToolCalls, Message, ToolCall, ToolRetry, UserPrompt
+from pydantic_ai.messages import (
+    ArgsJson,
+    LLMMessage,
+    LLMResponse,
+    LLMToolCalls,
+    Message,
+    ToolCall,
+    ToolRetry,
+    UserPrompt,
+)
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from tests.conftest import IsNow
 
@@ -107,3 +116,38 @@ def validate_result(r: Foo) -> Foo:
             LLMToolCalls(calls=[ToolCall.from_json('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow()),
         ]
     )
+
+
+def test_plain_response():
+    call_index = 0
+
+    def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage:
+        nonlocal call_index
+
+        assert info.result_tool is not None
+        call_index += 1
+        if call_index == 1:
+            return LLMResponse(content='hello')
+        else:
+            args_json = '{"response": ["foo", "bar"]}'
+            return LLMToolCalls(calls=[ToolCall.from_json(info.result_tool.name, args_json)])
+
+    agent = Agent(FunctionModel(return_tuple), deps=None, result_type=tuple[str, str])
+
+    result = agent.run_sync('Hello')
+    assert result.response == ('foo', 'bar')
+    assert call_index == 2
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMResponse(content='hello', timestamp=IsNow()),
+            UserPrompt(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+                timestamp=IsNow(),
+            ),
+            LLMToolCalls(
+                calls=[ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"response": ["foo", "bar"]}'))],
+                timestamp=IsNow(),
+            ),
+        ]
+    )