Skip to content

Commit

Permalink
extend RunContext (#570)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Dec 30, 2024
1 parent 2289879 commit fde6c9a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
45 changes: 24 additions & 21 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
import logfire._internal.stack_info
except ImportError:
pass
else:
from pathlib import Path

logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.
Expand Down Expand Up @@ -215,7 +225,7 @@ async def run(
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
model_used, mode_selection = await self._get_model(model)
model_used = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
Expand All @@ -224,11 +234,10 @@ async def run(
'{agent_name} run {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

Expand All @@ -238,15 +247,14 @@ async def run(
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
usage_limits.check_before_request(run_context.usage)

run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
run_context.run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request', run_step=run_step) as model_req_span:
with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
model_response, request_usage = await agent_model.request(messages, model_settings)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('usage', request_usage)
Expand All @@ -255,7 +263,7 @@ async def run(
run_context.usage.incr(request_usage, requests=1)
usage_limits.check_tokens(run_context.usage)

with _logfire.span('handle model response', run_step=run_step) as handle_span:
with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, run_context)

if tool_responses:
Expand Down Expand Up @@ -377,7 +385,7 @@ async def main():
# f_back because `asynccontextmanager` adds one frame
if frame := inspect.currentframe(): # pragma: no branch
self._infer_name(frame.f_back)
model_used, mode_selection = await self._get_model(model)
model_used = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
Expand All @@ -386,11 +394,10 @@ async def main():
'{agent_name} run stream {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

Expand All @@ -400,15 +407,14 @@ async def main():
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
run_step += 1
run_context.run_step += 1
usage_limits.check_before_request(run_context.usage)

with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
run_context.usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
Expand Down Expand Up @@ -781,14 +787,14 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None:

self._function_tools[tool.name] = tool

async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
"""Create a model configured for this agent.
Args:
model: model to use for this run, required if `model` was not set when creating the agent.
Returns:
a tuple of `(model used, how the model was selected)`
The model used
"""
model_: models.Model
if some_model := self._override_model:
Expand All @@ -799,18 +805,15 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
'(Even when `override(model=...)` is customizing the model that will actually be called)'
)
model_ = some_model.value
mode_selection = 'override-model'
elif model is not None:
model_ = models.infer_model(model)
mode_selection = 'custom'
elif self.model is not None:
# noinspection PyTypeChecker
model_ = self.model = models.infer_model(self.model)
mode_selection = 'from-agent'
else:
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

return model_, mode_selection
return model_

async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
"""Build tools and create an agent model."""
Expand Down
16 changes: 10 additions & 6 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@ class RunContext(Generic[AgentDeps]):

deps: AgentDeps
"""Dependencies for the agent."""
retry: int
"""Number of retries so far."""
messages: list[_messages.ModelMessage]
"""Messages exchanged in the conversation so far."""
tool_name: str | None
"""Name of the tool being called."""
model: models.Model
"""The model used in this run."""
usage: Usage
"""LLM usage associated with the run."""
prompt: str
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
"""Messages exchanged in the conversation so far."""
tool_name: str | None = None
"""Name of the tool being called."""
retry: int = 0
"""Number of retries so far."""
run_step: int = 0
"""The current step in the run."""

def replace_with(
self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
Expand Down
10 changes: 4 additions & 6 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ async def my_ret(x: int) -> str:
)
assert summary.attributes[0] == snapshot(
{
'code.filepath': 'agent.py',
'code.function': 'run',
'code.filepath': 'test_logfire.py',
'code.function': 'test_logfire',
'code.lineno': 123,
'prompt': 'Hello',
'agent': IsJson(
Expand All @@ -111,7 +111,6 @@ async def my_ret(x: int) -> str:
'model_settings': None,
}
),
'mode_selection': 'from-agent',
'model_name': 'test-model',
'agent_name': 'my_agent',
'logfire.msg_template': '{agent_name} run {prompt=}',
Expand Down Expand Up @@ -176,7 +175,6 @@ async def my_ret(x: int) -> str:
'model': {'type': 'object', 'title': 'TestModel', 'x-python-datatype': 'dataclass'}
},
},
'mode_selection': {},
'model_name': {},
'agent_name': {},
'all_messages': {
Expand Down Expand Up @@ -263,8 +261,8 @@ async def my_ret(x: int) -> str:
)
assert summary.attributes[1] == snapshot(
{
'code.filepath': 'agent.py',
'code.function': 'run',
'code.filepath': 'test_logfire.py',
'code.function': 'test_logfire',
'code.lineno': IsInt(),
'run_step': 1,
'logfire.msg_template': 'preparing model and tools {run_step=}',
Expand Down

0 comments on commit fde6c9a

Please sign in to comment.