Skip to content

Commit

Permalink
Fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
josead committed Jan 1, 2025
1 parent 49bf36f commit e8f798e
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import sys
from datetime import timezone
from typing import Any, Callable, Union
Expand Down Expand Up @@ -228,7 +229,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
)


def test_plain_response(set_event_loop: None):
def test_plain_response_then_tuple(set_event_loop: None):
call_index = 0

def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
Expand Down Expand Up @@ -272,6 +273,42 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
),
]
)
assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage]
assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
ModelRequest(
parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))]
)
)
assert result.all_messages()[-1] == snapshot(
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
)
]
)
)


def test_result_tool_return_content_str_return(set_event_loop: None):
agent = Agent('test')

result = agent.run_sync('Hello')
assert result.data == 'success (no tool calls)'

msg = re.escape('Cannot set result tool return content when the return type is `str`.')
with pytest.raises(ValueError, match=msg):
result.all_messages(result_tool_return_content='foobar')


def test_result_tool_return_content_no_tool(set_event_loop: None):
agent = Agent('test', result_type=int)

result = agent.run_sync('Hello')
assert result.data == 0
result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage]
with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")):
result.all_messages(result_tool_return_content='foobar')


def test_response_tuple(set_event_loop: None):
Expand Down Expand Up @@ -508,6 +545,7 @@ async def ret_a(x: str) -> str:
],
_new_message_index=4,
data='{"ret_a":"a-apple"}',
_result_tool_name=None,
_usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
)
)
Expand Down Expand Up @@ -550,6 +588,7 @@ async def ret_a(x: str) -> str:
],
_new_message_index=4,
data='{"ret_a":"a-apple"}',
_result_tool_name=None,
_usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
)
)
Expand Down Expand Up @@ -649,6 +688,7 @@ async def ret_a(x: str) -> str:
),
],
_new_message_index=5,
_result_tool_name='final_result',
_usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None),
)
)
Expand Down Expand Up @@ -1219,12 +1259,8 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
assert messages == []
result = agent.run_sync('Hello')
assert result.data == 'success (no tool calls)'
with pytest.raises(UserError) as exc_info:
agent.run_sync('Hello')
assert (
str(exc_info.value)
== 'The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().'
)
result2 = agent.run_sync('Hello 2')
assert result2.data == 'success (no tool calls)'
assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
Expand Down

0 comments on commit e8f798e

Please sign in to comment.