Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josead committed Jan 1, 2025
1 parent ade332e commit 2e63819
Showing 1 changed file with 11 additions and 44 deletions.
55 changes: 11 additions & 44 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import re
import sys
from datetime import timezone
from typing import Any, Callable, Union
Expand Down Expand Up @@ -229,7 +228,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
)


def test_plain_response_then_tuple(set_event_loop: None):
def test_plain_response(set_event_loop: None):
call_index = 0

def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
Expand Down Expand Up @@ -273,42 +272,6 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
),
]
)
assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage]
assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot(
ModelRequest(
parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))]
)
)
assert result.all_messages()[-1] == snapshot(
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
)
]
)
)


def test_result_tool_return_content_str_return(set_event_loop: None):
agent = Agent('test')

result = agent.run_sync('Hello')
assert result.data == 'success (no tool calls)'

msg = re.escape('Cannot set result tool return content when the return type is `str`.')
with pytest.raises(ValueError, match=msg):
result.all_messages(result_tool_return_content='foobar')


def test_result_tool_return_content_no_tool(set_event_loop: None):
agent = Agent('test', result_type=int)

result = agent.run_sync('Hello')
assert result.data == 0
result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage]
with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")):
result.all_messages(result_tool_return_content='foobar')


def test_response_tuple(set_event_loop: None):
Expand Down Expand Up @@ -545,7 +508,6 @@ async def ret_a(x: str) -> str:
],
_new_message_index=4,
data='{"ret_a":"a-apple"}',
_result_tool_name=None,
_usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
)
)
Expand Down Expand Up @@ -588,7 +550,6 @@ async def ret_a(x: str) -> str:
],
_new_message_index=4,
data='{"ret_a":"a-apple"}',
_result_tool_name=None,
_usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None),
)
)
Expand Down Expand Up @@ -688,7 +649,6 @@ async def ret_a(x: str) -> str:
),
],
_new_message_index=5,
_result_tool_name='final_result',
_usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None),
)
)
Expand Down Expand Up @@ -1259,9 +1219,12 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
assert messages == []
result = agent.run_sync('Hello')
assert result.data == 'success (no tool calls)'
result2 = agent.run_sync('Hello 2')
assert result2.data == 'success (no tool calls)'

with pytest.raises(UserError) as exc_info:
agent.run_sync('Hello')
assert (
str(exc_info.value)
== 'The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().'
)
assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
Expand Down Expand Up @@ -1420,6 +1383,10 @@ async def func_two_dynamic():
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
agent_outer = Agent('test')
Expand Down

0 comments on commit 2e63819

Please sign in to comment.