Skip to content

Commit

Permalink
Make capture_run_messages support nested agent calls (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Dec 30, 2024
1 parent fde6c9a commit d94931e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 35 deletions.
4 changes: 1 addition & 3 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,4 @@ with capture_run_messages() as messages: # (2)!
_(This example is complete, it can be run "as is")_

!!! note
You may not call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context.

If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
If you call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only.
57 changes: 31 additions & 26 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from types import FrameType
from typing import Any, Callable, Generic, Literal, cast, final, overload

Expand Down Expand Up @@ -60,7 +59,7 @@


@final
@dataclass(init=False)
@dataclasses.dataclass(init=False)
class Agent(Generic[AgentDeps, ResultData]):
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
Expand Down Expand Up @@ -100,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
be merged with this value, with the runtime argument taking priority.
"""

_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
_allow_text_result: bool = field(repr=False)
_system_prompts: tuple[str, ...] = field(repr=False)
_function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
_default_retries: int = field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
_deps_type: type[AgentDeps] = field(repr=False)
_max_result_retries: int = field(repr=False)
_override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
_override_model: _utils.Option[models.Model] = field(default=None, repr=False)
_result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
_allow_text_result: bool = dataclasses.field(repr=False)
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
_function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
_default_retries: int = dataclasses.field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
_deps_type: type[AgentDeps] = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)

def __init__(
self,
Expand Down Expand Up @@ -836,15 +835,15 @@ async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
) -> list[_messages.ModelMessage]:
try:
messages = _messages_ctx_var.get()
ctx_messages = _messages_ctx_var.get()
except LookupError:
messages = []
messages: list[_messages.ModelMessage] = []
else:
if messages:
raise exceptions.UserError(
'The capture_run_messages() context manager may only be used to wrap '
'one call to run(), run_sync(), or run_stream().'
)
if ctx_messages.used:
messages = []
else:
messages = ctx_messages.messages
ctx_messages.used = True

if message_history:
# shallow copy messages
Expand Down Expand Up @@ -1138,7 +1137,13 @@ def last_run_messages(self) -> list[_messages.ModelMessage]:
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')


_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
@dataclasses.dataclass
class _RunMessages:
messages: list[_messages.ModelMessage]
used: bool = False


_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')


@contextmanager
Expand All @@ -1162,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
```
!!! note
You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
`messages` will represent the messages exchanged during the first call only.
"""
try:
yield _messages_ctx_var.get()
yield _messages_ctx_var.get().messages
except LookupError:
messages: list[_messages.ModelMessage] = []
token = _messages_ctx_var.set(messages)
token = _messages_ctx_var.set(_RunMessages(messages))
try:
yield messages
finally:
_messages_ctx_var.reset(token)


@dataclass
@dataclasses.dataclass
class _MarkFinalResult(Generic[ResultData]):
"""Marker class to indicate that the result is the final result.
Expand Down
41 changes: 35 additions & 6 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,15 +1218,44 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
assert messages == []
result = agent.run_sync('Hello')
assert result.data == 'success (no tool calls)'
with pytest.raises(UserError) as exc_info:
agent.run_sync('Hello')
assert (
str(exc_info.value)
== 'The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().'
)
result2 = agent.run_sync('Hello 2')
assert result2.data == 'success (no tool calls)'

assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(parts=[TextPart(content='success (no tool calls)')], timestamp=IsNow(tz=timezone.utc)),
]
)


def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
agent_outer = Agent('test')
agent_inner = Agent(TestModel(custom_result_text='inner agent result'))

@agent_outer.tool_plain
async def foobar(x: str) -> str:
result_ = await agent_inner.run(x)
return result_.data

with capture_run_messages() as messages:
result = agent_outer.run_sync('foobar')

assert result.data == snapshot('{"foobar":"inner agent result"}')
assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(tool_name='foobar', content='inner agent result', timestamp=IsNow(tz=timezone.utc))
]
),
ModelResponse(
parts=[TextPart(content='{"foobar":"inner agent result"}')], timestamp=IsNow(tz=timezone.utc)
),
]
)

0 comments on commit d94931e

Please sign in to comment.