From ae82fa4fc9be12b6119869a998a607ac23613bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Dom=C3=ADnguez?= Date: Tue, 7 Jan 2025 08:25:29 -0300 Subject: [PATCH] Adds `dynamic` to `system_prompt` decorator, allowing reevaluation (#560) Co-authored-by: Samuel Colvin --- docs/message-history.md | 20 ++- docs/tools.md | 1 + .../pydantic_ai/_system_prompt.py | 1 + pydantic_ai_slim/pydantic_ai/agent.py | 78 ++++++++-- pydantic_ai_slim/pydantic_ai/messages.py | 6 + tests/test_agent.py | 144 +++++++++++++++++- 6 files changed, 233 insertions(+), 17 deletions(-) diff --git a/docs/message-history.md b/docs/message-history.md index a8e94209..e730d2c4 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -43,7 +43,9 @@ print(result.all_messages()) ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', part_kind='system-prompt' + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', @@ -85,7 +87,9 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', part_kind='system-prompt' + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', @@ -112,7 +116,9 @@ async def main(): ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', part_kind='system-prompt' + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', @@ -166,7 +172,9 @@ print(result2.all_messages()) ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', part_kind='system-prompt' + content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', @@ -238,7 +246,9 @@ print(result2.all_messages()) ModelRequest( parts=[ SystemPromptPart( - content='Be a helpful assistant.', part_kind='system-prompt' + 
content='Be a helpful assistant.', + dynamic_ref=None, + part_kind='system-prompt', ), UserPromptPart( content='Tell me a joke.', diff --git a/docs/tools.md b/docs/tools.md index ade96bd1..b5e2c004 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -72,6 +72,7 @@ print(dice_result.all_messages()) parts=[ SystemPromptPart( content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.", + dynamic_ref=None, part_kind='system-prompt', ), UserPromptPart( diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index 27e59b8f..4cbc64da 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -12,6 +12,7 @@ @dataclass class SystemPromptRunner(Generic[AgentDeps]): function: SystemPromptFunc[AgentDeps] + dynamic: bool = False _takes_ctx: bool = field(init=False) _is_async: bool = field(init=False) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b4af5db7..d044c46b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -107,6 +107,9 @@ class Agent(Generic[AgentDeps, ResultData]): _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False) _default_retries: int = dataclasses.field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False) + _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field( + repr=False + ) _deps_type: type[AgentDeps] = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False) @@ -182,6 +185,7 @@ def __init__( self._register_tool(Tool(tool)) self._deps_type = deps_type 
self._system_prompt_functions = [] + self._system_prompt_dynamic_functions = {} self._max_result_retries = result_retries if result_retries is not None else retries self._result_validators = [] @@ -535,17 +539,37 @@ def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ... @overload def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... + @overload + def system_prompt( + self, /, *, dynamic: bool = False + ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ... + def system_prompt( - self, func: _system_prompt.SystemPromptFunc[AgentDeps], / - ) -> _system_prompt.SystemPromptFunc[AgentDeps]: + self, + func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None, + /, + *, + dynamic: bool = False, + ) -> ( + Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]] + | _system_prompt.SystemPromptFunc[AgentDeps] + ): """Decorator to register a system prompt function. Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument. Can decorate a sync or async functions. + The decorator can be used either bare (`agent.system_prompt`) or as a function call + (`agent.system_prompt(...)`), see the examples below. + Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure the type of the function, see `tests/typed_agent.py` for tests. 
+ Args: + func: The function to decorate + dynamic: If True, the system prompt will be reevaluated even when `message_history` is provided, + see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref] + Example: ```python from pydantic_ai import Agent, RunContext @@ -556,17 +580,27 @@ def system_prompt( def simple_system_prompt() -> str: return 'foobar' - @agent.system_prompt + @agent.system_prompt(dynamic=True) async def async_system_prompt(ctx: RunContext[str]) -> str: return f'{ctx.deps} is the best' - - result = agent.run_sync('foobar', deps='spam') - print(result.data) - #> success (no tool calls) ``` """ - self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func)) - return func + if func is None: + + def decorator( + func_: _system_prompt.SystemPromptFunc[AgentDeps], + ) -> _system_prompt.SystemPromptFunc[AgentDeps]: + runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic) + self._system_prompt_functions.append(runner) + if dynamic: + self._system_prompt_dynamic_functions[func_.__qualname__] = runner + return func_ + + return decorator + else: + assert not dynamic, "dynamic can't be True in this case" + self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic)) + return func @overload def result_validator( @@ -835,6 +869,23 @@ async def add_tool(tool: Tool[AgentDeps]) -> None: result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [], ) + async def _reevaluate_dynamic_prompts( + self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps] + ) -> None: + """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function.""" + # Only proceed if there's at least one dynamic runner. 
+ if self._system_prompt_dynamic_functions: + for msg in messages: + if isinstance(msg, _messages.ModelRequest): + for i, part in enumerate(msg.parts): + if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref: + # Look up the runner by its ref + if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref): + updated_part_content = await runner.run(run_context) + msg.parts[i] = _messages.SystemPromptPart( + updated_part_content, dynamic_ref=part.dynamic_ref + ) + async def _prepare_messages( self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps] ) -> list[_messages.ModelMessage]: @@ -850,8 +901,10 @@ async def _prepare_messages( ctx_messages.used = True if message_history: - # shallow copy messages + # Shallow copy messages messages.extend(message_history) + # Reevaluate any dynamic system prompt parts + await self._reevaluate_dynamic_prompts(messages, run_context) messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)])) else: parts = await self._sys_parts(run_context) @@ -1088,7 +1141,10 @@ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts] for sys_prompt_runner in self._system_prompt_functions: prompt = await sys_prompt_runner.run(run_context) - messages.append(_messages.SystemPromptPart(prompt)) + if sys_prompt_runner.dynamic: + messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)) + else: + messages.append(_messages.SystemPromptPart(prompt)) return messages def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index a3cdd25a..3941b7de 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ 
b/pydantic_ai_slim/pydantic_ai/messages.py @@ -21,6 +21,12 @@ class SystemPromptPart: content: str """The content of the prompt.""" + dynamic_ref: str | None = None + """The ref of the dynamic system prompt function that generated this part. + + Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information. + """ + part_kind: Literal['system-prompt'] = 'system-prompt' """Part type identifier, this is available on all parts as a discriminator.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 277eeeb5..e2269ff3 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1260,7 +1260,6 @@ def test_double_capture_run_messages(set_event_loop: None) -> None: assert result.data == 'success (no tool calls)' result2 = agent.run_sync('Hello 2') assert result2.data == 'success (no tool calls)' - assert messages == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), @@ -1269,6 +1268,149 @@ def test_double_capture_run_messages(set_event_loop: None) -> None: ) +def test_dynamic_false_no_reevaluate(set_event_loop: None): + """When dynamic is false (default), the system prompt is not reevaluated + i.e: SystemPromptPart( + content="A", <--- Remains the same when `message_history` is passed. 
+ part_kind='system-prompt') + """ + agent = Agent('test', system_prompt='Foobar') + + dynamic_value = 'A' + + @agent.system_prompt + async def func() -> str: + return dynamic_value + + res = agent.run_sync('Hello') + + assert res.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart(content=dynamic_value, part_kind='system-prompt'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + ], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ] + ) + + dynamic_value = 'B' + + res_two = agent.run_sync('World', message_history=res.all_messages()) + + assert res_two.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart( + content='A', # Remains the same + part_kind='system-prompt', + ), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + ], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ModelRequest( + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ] + ) + + +def test_dynamic_true_reevaluate_system_prompt(set_event_loop: None): + """When dynamic is true, the system prompt is reevaluated + i.e: SystemPromptPart( + content="B", <--- Updated value + part_kind='system-prompt') + """ + agent = Agent('test', system_prompt='Foobar') + + dynamic_value = 'A' + + @agent.system_prompt(dynamic=True) + async def func(): + return dynamic_value + + res = 
agent.run_sync('Hello') + + assert res.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart( + content=dynamic_value, + part_kind='system-prompt', + dynamic_ref=func.__qualname__, + ), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + ], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ] + ) + + dynamic_value = 'B' + + res_two = agent.run_sync('World', message_history=res.all_messages()) + + assert res_two.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar', part_kind='system-prompt'), + SystemPromptPart( + content='B', + part_kind='system-prompt', + dynamic_ref=func.__qualname__, + ), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'), + ], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ModelRequest( + parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')], + kind='request', + ), + ModelResponse( + parts=[TextPart(content='success (no tool calls)', part_kind='text')], + timestamp=IsNow(tz=timezone.utc), + kind='response', + ), + ] + ) + + def test_capture_run_messages_tool_agent(set_event_loop: None) -> None: agent_outer = Agent('test') agent_inner = Agent(TestModel(custom_result_text='inner agent result'))