Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
josead committed Dec 30, 2024
1 parent 38c5f62 commit f2b793f
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 135 deletions.
15 changes: 9 additions & 6 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,13 @@ async def async_system_prompt(ctx: RunContext[str]) -> str:
```
"""
if func is None:

def decorator(
func_: _system_prompt.SystemPromptFunc[AgentDeps],
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func_, dynamic=dynamic))
return func_

return decorator
else:
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
Expand Down Expand Up @@ -838,7 +840,9 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
)

async def _prepare_messages(self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]) -> list[_messages.ModelMessage]:
async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
) -> list[_messages.ModelMessage]:
try:
messages = _messages_ctx_var.get()
except LookupError:
Expand All @@ -853,22 +857,21 @@ async def _prepare_messages(self, user_prompt: str, message_history: list[_messa
if message_history:
# shallow copy messages
messages.extend(message_history)

# If there are any dynamic system prompts, we need to reevaluate them
if any(runner.dynamic for runner in self._system_prompt_functions):
# Get fresh system prompts
new_sys_parts = await self._sys_parts(run_context)

# Replace the system prompts in the existing messages
for msg in messages:
if isinstance(msg, _messages.ModelRequest):
# Keep non-system parts and add new system parts
non_system_parts = [
part for part in msg.parts
if not isinstance(part, _messages.SystemPromptPart)
part for part in msg.parts if not isinstance(part, _messages.SystemPromptPart)
]
msg.parts = new_sys_parts + non_system_parts

messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
else:
parts = await self._sys_parts(run_context)
Expand Down
233 changes: 104 additions & 129 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,14 +1233,14 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:


def test_dynamic_false_no_reevaluate(set_event_loop: None):
""" When dynamic is false (default), the system prompt is not reevaluated
"""When dynamic is false (default), the system prompt is not reevaluated
i.e: SystemPromptPart(
content="A", <--- Remains the same when `message_history` is passed.
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')
dynamic_value = "A"

dynamic_value = 'A'

@agent.system_prompt
async def func(ctx):
Expand All @@ -1252,74 +1252,68 @@ async def func(ctx):
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = "B"
dynamic_value = 'B'

@agent.system_prompt
async def func_two(ctx):
return dynamic_value + "!"
return dynamic_value + '!'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="A", #Remains the same
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
content='A', # Remains the same
part_kind='system-prompt',
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


def test_dynamic_true_reevaluate_system_prompt(set_event_loop: None):
""" When dynamic is true, the system prompt is reevaluated
"""When dynamic is true, the system prompt is reevaluated
i.e: SystemPromptPart(
content="B", <--- Updated value
part_kind='system-prompt')
"""
agent = Agent('test', system_prompt='Foobar')
dynamic_value = "A"

dynamic_value = 'A'

@agent.system_prompt(dynamic=True)
async def func(ctx):
Expand All @@ -1331,72 +1325,64 @@ async def func(ctx):
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = "B"
dynamic_value = 'B'

@agent.system_prompt
async def func_two(ctx):
return "This is a new prompt, but it wont reach the model"
return 'This is a new prompt, but it wont reach the model'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="B", # Updated value
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
content='B', # Updated value
part_kind='system-prompt',
),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)



def test_dynamic_true_evaluate_new_system_prompt(set_event_loop: None):
""" When new system prompt added with `dynamic` = True, they will be evaluated and added to the system parts, (besides the reevaluated ones)."""
"""When new system prompt added with `dynamic` = True, they will be evaluated and added to the system parts, (besides the reevaluated ones)."""
agent = Agent('test', system_prompt='Foobar')
dynamic_value = "A"

dynamic_value = 'A'

@agent.system_prompt(dynamic=True)
async def func(ctx):
Expand All @@ -1408,66 +1394,55 @@ async def func(ctx):
[
ModelRequest(
parts=[
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content=dynamic_value,
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(content=dynamic_value, part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)

dynamic_value = "B"
dynamic_value = 'B'

@agent.system_prompt(dynamic=True)
async def func_two(ctx):
return "This is a new prompt, and model will know"
return 'This is a new prompt, and model will know'

res_two = agent.run_sync('World', message_history=res.all_messages())

assert res_two.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(content='Foobar', part_kind='system-prompt'),
SystemPromptPart(
content='Foobar',
part_kind='system-prompt'),
SystemPromptPart(
content="B", # Updated value since dirty
part_kind='system-prompt'),
SystemPromptPart(
content="This is a new prompt, and model will know",
part_kind='system-prompt'),
UserPromptPart(
content='Hello',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
content='B', # Updated value since dirty
part_kind='system-prompt',
),
SystemPromptPart(content='This is a new prompt, and model will know', part_kind='system-prompt'),
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt'),
],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response'),
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
ModelRequest(
parts=[
UserPromptPart(
content='World',
timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')
], kind='request'),
parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc), part_kind='user-prompt')],
kind='request',
),
ModelResponse(
parts=[
TextPart(content='success (no tool calls)', part_kind='text')
],
timestamp=IsNow(tz=timezone.utc), kind='response')
parts=[TextPart(content='success (no tool calls)', part_kind='text')],
timestamp=IsNow(tz=timezone.utc),
kind='response',
),
]
)


0 comments on commit f2b793f

Please sign in to comment.