Skip to content

Commit

Permalink
Adds dynamic to system_prompt decorator, allowing reevaluation (#560
Browse files Browse the repository at this point in the history
)

Co-authored-by: Samuel Colvin <[email protected]>
  • Loading branch information
josead and samuelcolvin authored Jan 7, 2025
1 parent 421ed97 commit ae82fa4
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 17 deletions.
20 changes: 15 additions & 5 deletions docs/message-history.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ print(result.all_messages())
ModelRequest(
parts=[
SystemPromptPart(
content='Be a helpful assistant.', part_kind='system-prompt'
content='Be a helpful assistant.',
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
content='Tell me a joke.',
Expand Down Expand Up @@ -85,7 +87,9 @@ async def main():
ModelRequest(
parts=[
SystemPromptPart(
content='Be a helpful assistant.', part_kind='system-prompt'
content='Be a helpful assistant.',
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
content='Tell me a joke.',
Expand All @@ -112,7 +116,9 @@ async def main():
ModelRequest(
parts=[
SystemPromptPart(
content='Be a helpful assistant.', part_kind='system-prompt'
content='Be a helpful assistant.',
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
content='Tell me a joke.',
Expand Down Expand Up @@ -166,7 +172,9 @@ print(result2.all_messages())
ModelRequest(
parts=[
SystemPromptPart(
content='Be a helpful assistant.', part_kind='system-prompt'
content='Be a helpful assistant.',
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
content='Tell me a joke.',
Expand Down Expand Up @@ -238,7 +246,9 @@ print(result2.all_messages())
ModelRequest(
parts=[
SystemPromptPart(
content='Be a helpful assistant.', part_kind='system-prompt'
content='Be a helpful assistant.',
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
content='Tell me a joke.',
Expand Down
1 change: 1 addition & 0 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ print(dice_result.all_messages())
parts=[
SystemPromptPart(
content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.",
dynamic_ref=None,
part_kind='system-prompt',
),
UserPromptPart(
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
function: SystemPromptFunc[AgentDeps]
dynamic: bool = False
_takes_ctx: bool = field(init=False)
_is_async: bool = field(init=False)

Expand Down
78 changes: 67 additions & 11 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ class Agent(Generic[AgentDeps, ResultData]):
_function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
_default_retries: int = dataclasses.field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
_system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
repr=False
)
_deps_type: type[AgentDeps] = dataclasses.field(repr=False)
_max_result_retries: int = dataclasses.field(repr=False)
_override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
Expand Down Expand Up @@ -182,6 +185,7 @@ def __init__(
self._register_tool(Tool(tool))
self._deps_type = deps_type
self._system_prompt_functions = []
self._system_prompt_dynamic_functions = {}
self._max_result_retries = result_retries if result_retries is not None else retries
self._result_validators = []

Expand Down Expand Up @@ -535,17 +539,37 @@ def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

@overload
def system_prompt(
self, /, *, dynamic: bool = False
) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...

def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self,
func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
/,
*,
dynamic: bool = False,
) -> (
Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
| _system_prompt.SystemPromptFunc[AgentDeps]
):
"""Decorator to register a system prompt function.
Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
Can decorate a sync or async functions.
The decorator can be used either bare (`agent.system_prompt`) or as a function call
(`agent.system_prompt(...)`), see the examples below.
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
the type of the function, see `tests/typed_agent.py` for tests.
Args:
func: The function to decorate
dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
Example:
```python
from pydantic_ai import Agent, RunContext
Expand All @@ -556,17 +580,27 @@ def system_prompt(
def simple_system_prompt() -> str:
return 'foobar'
@agent.system_prompt
@agent.system_prompt(dynamic=True)
async def async_system_prompt(ctx: RunContext[str]) -> str:
return f'{ctx.deps} is the best'
result = agent.run_sync('foobar', deps='spam')
print(result.data)
#> success (no tool calls)
```
"""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func
if func is None:

def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
self._system_prompt_functions.append(runner)
if dynamic:
self._system_prompt_dynamic_functions[func_.__qualname__] = runner
return func_

return decorator
else:
assert not dynamic, "dynamic can't be True in this case"
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
return func

@overload
def result_validator(
Expand Down Expand Up @@ -835,6 +869,23 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
)

async def _reevaluate_dynamic_prompts(
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
) -> None:
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
# Only proceed if there's at least one dynamic runner.
if self._system_prompt_dynamic_functions:
for msg in messages:
if isinstance(msg, _messages.ModelRequest):
for i, part in enumerate(msg.parts):
if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
# Look up the runner by its ref
if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
updated_part_content = await runner.run(run_context)
msg.parts[i] = _messages.SystemPromptPart(
updated_part_content, dynamic_ref=part.dynamic_ref
)

async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
) -> list[_messages.ModelMessage]:
Expand All @@ -850,8 +901,10 @@ async def _prepare_messages(
ctx_messages.used = True

if message_history:
# shallow copy messages
# Shallow copy messages
messages.extend(message_history)
# Reevaluate any dynamic system prompt parts
await self._reevaluate_dynamic_prompts(messages, run_context)
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
else:
parts = await self._sys_parts(run_context)
Expand Down Expand Up @@ -1088,7 +1141,10 @@ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
for sys_prompt_runner in self._system_prompt_functions:
prompt = await sys_prompt_runner.run(run_context)
messages.append(_messages.SystemPromptPart(prompt))
if sys_prompt_runner.dynamic:
messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
else:
messages.append(_messages.SystemPromptPart(prompt))
return messages

def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class SystemPromptPart:
content: str
"""The content of the prompt."""

dynamic_ref: str | None = None
"""The ref of the dynamic system prompt function that generated this part.
Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
"""

part_kind: Literal['system-prompt'] = 'system-prompt'
"""Part type identifier, this is available on all parts as a discriminator."""

Expand Down
144 changes: 143 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,6 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
assert result.data == 'success (no tool calls)'
result2 = agent.run_sync('Hello 2')
assert result2.data == 'success (no tool calls)'

assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
Expand All @@ -1269,6 +1268,149 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
)


def test_dynamic_false_no_reevaluate(set_event_loop: None):
"""When dynamic is false (default), the system prompt is not reevaluated
i.e: SystemPromptPart(
content="A", <--- Remains the same when `message_history` is passed.
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')

dynamic_value = 'A'

@agent.system_prompt
async def func() -> str:
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = 'B'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='A', # Remains the same
part_kind='system-prompt',
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_dynamic_true_reevaluate_system_prompt(set_event_loop: None):
"""When dynamic is true, the system prompt is reevaluated
i.e: SystemPromptPart(
content="B", <--- Updated value
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')

dynamic_value = 'A'

@agent.system_prompt(dynamic=True)
async def func():
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt',
dynamic_ref=func.__qualname__,
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = 'B'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='B',
part_kind='system-prompt',
dynamic_ref=func.__qualname__,
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
agent_outer = Agent('test')
agent_inner = Agent(TestModel(custom_result_text='inner agent result'))
Expand Down

0 comments on commit ae82fa4

Please sign in to comment.