Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove PlainResponseForbidden #12

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,12 @@ async def _handle_model_response(
if self._allow_text_result:
return _utils.Either(left=cast(result.ResultData, model_response.content))
else:
return _utils.Either(right=[_messages.PlainResponseForbidden()])
self._incr_result_retry()
assert self._result_tool is not None
response = _messages.UserPrompt(
content='Plain text responses are not permitted, please call one of the functions instead.',
)
return _utils.Either(right=[response])
elif model_response.role == 'llm-tool-calls':
if self._result_tool is not None:
# if there's a result schema, and any of the calls match that name, return the result
Expand Down
13 changes: 1 addition & 12 deletions pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,6 @@ def llm_response(self) -> str:
return f'{description}\n\nFix the errors and try again.'


@dataclass
class PlainResponseForbidden:
# TODO remove and replace with ToolRetry
timestamp: datetime = field(default_factory=datetime.now)
role: Literal['plain-response-forbidden'] = 'plain-response-forbidden'

@staticmethod
def llm_response() -> str:
return 'Plain text responses are not allowed, please call one of the functions instead.'


@dataclass
class LLMResponse:
content: str
Expand Down Expand Up @@ -105,6 +94,6 @@ class LLMToolCalls:


LLMMessage = Union[LLMResponse, LLMToolCalls]
Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, PlainResponseForbidden, LLMMessage]
Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, LLMMessage]

MessagesTypeAdapter = pydantic.TypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
2 changes: 0 additions & 2 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiConte
elif m.role == 'tool-retry':
# ToolRetry ->
return _utils.Either(right=_GeminiContent.function_retry(m))
elif m.role == 'plain-response-forbidden':
return _utils.Either(right=_GeminiContent.user_text(m.llm_response()))
elif m.role == 'llm-response':
# LLMResponse ->
return _utils.Either(right=_GeminiContent.model_text(m.content))
Expand Down
6 changes: 0 additions & 6 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,6 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam:
role='assistant',
tool_calls=[_map_tool_call(t) for t in message.calls],
)
elif message.role == 'plain-response-forbidden':
# PlainResponseForbidden ->
return chat.ChatCompletionUserMessageParam(
role='user',
content=message.llm_response(),
)
else:
assert_never(message)

Expand Down
46 changes: 45 additions & 1 deletion tests/test_result_validation.py → tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
from pydantic import BaseModel

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.messages import LLMMessage, LLMToolCalls, Message, ToolCall, ToolRetry, UserPrompt
from pydantic_ai.messages import (
ArgsJson,
LLMMessage,
LLMResponse,
LLMToolCalls,
Message,
ToolCall,
ToolRetry,
UserPrompt,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from tests.conftest import IsNow

Expand Down Expand Up @@ -107,3 +116,38 @@ def validate_result(r: Foo) -> Foo:
LLMToolCalls(calls=[ToolCall.from_json('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow()),
]
)


def test_plain_response():
call_index = 0

def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage:
nonlocal call_index

assert info.result_tool is not None
call_index += 1
if call_index == 1:
return LLMResponse(content='hello')
else:
args_json = '{"response": ["foo", "bar"]}'
return LLMToolCalls(calls=[ToolCall.from_json(info.result_tool.name, args_json)])

agent = Agent(FunctionModel(return_tuple), deps=None, result_type=tuple[str, str])

result = agent.run_sync('Hello')
assert result.response == ('foo', 'bar')
assert call_index == 2
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
LLMResponse(content='hello', timestamp=IsNow()),
UserPrompt(
content='Plain text responses are not permitted, please call one of the functions instead.',
timestamp=IsNow(),
),
LLMToolCalls(
calls=[ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"response": ["foo", "bar"]}'))],
timestamp=IsNow(),
),
]
)
Loading