Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds dynamic to system_prompt decorator, allowing reevaluation #560

Merged
merged 32 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c248ade
Adds dynamic to system_prompt decorator, allowing reevaluation
josead Dec 28, 2024
71c8808
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Dec 30, 2024
f63ede3
lint fix
josead Dec 30, 2024
11b5abb
removing useless overload
josead Dec 30, 2024
38c5f62
Adds tests better naming
josead Dec 30, 2024
f2b793f
lint
josead Dec 30, 2024
b56bb1c
removes unused
josead Dec 30, 2024
0453499
Adds dynamic id to parts, referencing the runner that created it.
josead Jan 1, 2025
ade332e
Merge branch 'main' into main
josead Jan 1, 2025
2e63819
fix tests
josead Jan 1, 2025
49bf36f
fix lint
josead Jan 1, 2025
e8f798e
Fix merge
josead Jan 1, 2025
01ab9f4
Fixing overload
josead Jan 2, 2025
20cd93a
Update pydantic_ai_slim/pydantic_ai/messages.py
josead Jan 2, 2025
59f0ef4
Update pydantic_ai_slim/pydantic_ai/messages.py
josead Jan 2, 2025
2c84e65
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
260a610
Removes unused system prompts
josead Jan 2, 2025
2a73e13
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
78db2c2
Adds changes to use qual name for ref
josead Jan 2, 2025
3c1cc83
lint changes, make
josead Jan 2, 2025
08f041d
Adds default value and fix tests
josead Jan 2, 2025
5a10407
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
440b443
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
3d14344
Update pydantic_ai_slim/pydantic_ai/agent.py
josead Jan 2, 2025
e69a351
adds callable and dynamic assert on func none
josead Jan 2, 2025
feda9ff
adds the assert to the correct part
josead Jan 2, 2025
9da26c5
Adds Callable as return type
josead Jan 2, 2025
68380b8
Adds dynamic ref in system prompt
josead Jan 2, 2025
740ea26
Fix examples tests
josead Jan 2, 2025
956f143
Changes on docs to support dynamic
josead Jan 2, 2025
f15ba42
Adds correct overload
josead Jan 5, 2025
90a6069
tweaks to docs
samuelcolvin Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
@dataclass
class SystemPromptRunner(Generic[AgentDeps]):
function: SystemPromptFunc[AgentDeps]
dynamic: bool = False
_takes_ctx: bool = field(init=False)
_is_async: bool = field(init=False)

Expand Down
81 changes: 68 additions & 13 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,16 +535,38 @@ def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

@overload
def system_prompt(
self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self, func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None, /, *, dynamic: bool = False
josead marked this conversation as resolved.
Show resolved Hide resolved
) -> Any:
if func is None:

def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func_, dynamic=dynamic))
return func_

return decorator
else:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
return func

def system_prompt(
self,
func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
/,
*,
dynamic: bool = False,
) -> Any:
"""Decorator to register a system prompt function.

Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
Can decorate a sync or async functions.

Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
the type of the function, see `tests/typed_agent.py` for tests.
Args:
func: The function to decorate
dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided.

Example:
```python
Expand All @@ -556,17 +578,23 @@ def system_prompt(
def simple_system_prompt() -> str:
return 'foobar'

@agent.system_prompt
@agent.system_prompt(dynamic=True)
async def async_system_prompt(ctx: RunContext[str]) -> str:
return f'{ctx.deps} is the best'

result = agent.run_sync('foobar', deps='spam')
print(result.data)
#> success (no tool calls)
```
"""
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
return func
if func is None:

def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func_, dynamic=dynamic))
return func_

return decorator
else:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
return func

@overload
def result_validator(
Expand Down Expand Up @@ -835,6 +863,25 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
)

async def _reevaluate_dynamic_prompts(
self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
) -> None:
"""Reevaluate any DynamicSystemPromptPart in the provided messages by running the associated runner function."""
josead marked this conversation as resolved.
Show resolved Hide resolved
# Pre-map runner IDs to runners for efficient lookups (instead of re-looping each time).
runner_map = {id(runner): runner for runner in self._system_prompt_functions}
josead marked this conversation as resolved.
Show resolved Hide resolved

# Only proceed if there's at least one dynamic runner.
if any(runner.dynamic for runner in self._system_prompt_functions):
for msg in messages:
if isinstance(msg, _messages.ModelRequest):
for i, part in enumerate(msg.parts):
if isinstance(part, _messages.DynamicSystemPromptPart):
# Look up the runner by its ID (default 0 in case part.ref is None).
matching_runner = runner_map.get(part.ref or 0)
if matching_runner is not None:
josead marked this conversation as resolved.
Show resolved Hide resolved
updated_part_content = await matching_runner.run(run_context)
msg.parts[i] = _messages.DynamicSystemPromptPart(updated_part_content, ref=part.ref)

async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
) -> list[_messages.ModelMessage]:
Expand All @@ -850,8 +897,13 @@ async def _prepare_messages(
ctx_messages.used = True

if message_history:
# shallow copy messages
# Shallow copy messages
messages.extend(message_history)

# Reevaluate any dynamic system prompt parts (now done by our extracted method)
josead marked this conversation as resolved.
Show resolved Hide resolved
await self._reevaluate_dynamic_prompts(messages, run_context)

# Finally, append the new user prompt
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
else:
parts = await self._sys_parts(run_context)
Expand Down Expand Up @@ -1088,7 +1140,10 @@ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
for sys_prompt_runner in self._system_prompt_functions:
prompt = await sys_prompt_runner.run(run_context)
messages.append(_messages.SystemPromptPart(prompt))
if sys_prompt_runner.dynamic:
messages.append(_messages.DynamicSystemPromptPart(prompt, ref=id(sys_prompt_runner)))
else:
messages.append(_messages.SystemPromptPart(prompt))
return messages

def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
Expand Down
11 changes: 11 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ class SystemPromptPart:
"""Part type identifier, this is available on all parts as a discriminator."""


@dataclass
class DynamicSystemPromptPart(SystemPromptPart):
"""A system prompt that is generated dynamically.

Same as SystemPromptPart, but its content is regenerated on each run.
josead marked this conversation as resolved.
Show resolved Hide resolved
"""

ref: int | None = None
josead marked this conversation as resolved.
Show resolved Hide resolved
"""The ref ID of the system prompt function that generated this part."""


@dataclass
class UserPromptPart:
"""A user prompt, generally written by the end user.
Expand Down
157 changes: 156 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
DynamicSystemPromptPart,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -1260,7 +1261,6 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
assert result.data == 'success (no tool calls)'
result2 = agent.run_sync('Hello 2')
assert result2.data == 'success (no tool calls)'

assert messages == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
Expand All @@ -1269,6 +1269,161 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
)


def test_dynamic_false_no_reevaluate(set_event_loop: None):
"""When dynamic is false (default), the system prompt is not reevaluated
i.e: SystemPromptPart(
content="A", <--- Remains the same when `message_history` is passed.
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')

dynamic_value = 'A'

@agent.system_prompt
async def func() -> str:
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = 'B'

@agent.system_prompt
async def func_two():
return 'This is not added'
josead marked this conversation as resolved.
Show resolved Hide resolved

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='A', # Remains the same
part_kind='system-prompt',
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_dynamic_true_reevaluate_system_prompt(set_event_loop: None):
"""When dynamic is true, the system prompt is reevaluated
i.e: SystemPromptPart(
content="B", <--- Updated value
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')

dynamic_value = 'A'

@agent.system_prompt(dynamic=True)
async def func():
return dynamic_value

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
DynamicSystemPromptPart(
content=dynamic_value,
part_kind='system-prompt',
ref=id(agent._system_prompt_functions[0]), # type: ignore
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = 'B'

@agent.system_prompt
async def func_two():
josead marked this conversation as resolved.
Show resolved Hide resolved
return 'This is a new prompt, but it wont reach the model'

@agent.system_prompt(dynamic=True)
async def func_two_dynamic():
josead marked this conversation as resolved.
Show resolved Hide resolved
return 'This is a new prompt, but it wont reach the model, even though is dynamic'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
DynamicSystemPromptPart(
content='B',
part_kind='system-prompt',
ref=id(agent._system_prompt_functions[0]), # type: ignore
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
agent_outer = Agent('test')
agent_inner = Agent(TestModel(custom_result_text='inner agent result'))
Expand Down
Loading