Add support for usage limits (#409)
Co-authored-by: sydney-runkle <[email protected]>
Co-authored-by: Sydney Runkle <[email protected]>
3 people authored Dec 18, 2024
1 parent 47d3e5c commit c23a286
Showing 23 changed files with 439 additions and 63 deletions.
84 changes: 84 additions & 0 deletions docs/agents.md
@@ -103,6 +103,90 @@ You can also pass messages from previous runs to continue a conversation or prov

### Additional Configuration

#### Usage Limits

PydanticAI offers a [`settings.UsageLimits`][pydantic_ai.settings.UsageLimits] structure to help you limit your
usage (tokens and/or requests) on model runs.

You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions.

Consider the following example, where we limit the number of response tokens:

```py
from pydantic_ai import Agent
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.settings import UsageLimits

agent = Agent('claude-3-5-sonnet-latest')

result_sync = agent.run_sync(
    'What is the capital of Italy? Answer with just the city.',
    usage_limits=UsageLimits(response_tokens_limit=10),
)
print(result_sync.data)
#> Rome
print(result_sync.usage())
"""
Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63, details=None)
"""

try:
    result_sync = agent.run_sync(
        'What is the capital of Italy? Answer with a paragraph.',
        usage_limits=UsageLimits(response_tokens_limit=10),
    )
except UsageLimitExceeded as e:
    print(e)
    #> Exceeded the response_tokens_limit of 10 (response_tokens=32)
```

Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling:

```py
from typing_extensions import TypedDict

from pydantic_ai import Agent, ModelRetry
from pydantic_ai.exceptions import UsageLimitExceeded
from pydantic_ai.settings import UsageLimits


class NeverResultType(TypedDict):
    """
    Never ever coerce data to this type.
    """

    never_use_this: str


agent = Agent(
    'claude-3-5-sonnet-latest',
    result_type=NeverResultType,
    system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.',
)


@agent.tool_plain(retries=5)  # (1)!
def infinite_retry_tool() -> int:
    raise ModelRetry('Please try again.')


try:
    result_sync = agent.run_sync(
        'Begin infinite retry loop!', usage_limits=UsageLimits(request_limit=3)  # (2)!
    )
except UsageLimitExceeded as e:
    print(e)
    #> The next request would exceed the request_limit of 3
```

1. This tool can retry up to 5 times before erroring, simulating a tool that might get stuck in a loop.
2. This run will error after 3 requests, preventing infinite tool calling.

!!! note
    This is especially relevant if you've registered many tools. `request_limit` can be used to prevent the model from choosing to make too many of these calls.

#### Model (Run) Settings

PydanticAI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine-tune your requests.
This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`,
`timeout`, and more.
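
For example, to pin the sampling temperature for a single run (a minimal sketch; `ModelSettings` is a `TypedDict`, so a plain dict of settings works, and whether a given parameter is honored depends on the model provider):

```py
from pydantic_ai import Agent

agent = Agent('claude-3-5-sonnet-latest')

result_sync = agent.run_sync(
    'What is the capital of Italy?',
    model_settings={'temperature': 0.0},  # low-variance sampling for this run only
)
print(result_sync.data)
#> Rome
```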
8 changes: 6 additions & 2 deletions docs/api/models/ollama.md
@@ -32,7 +32,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
```

## Example using a remote server
@@ -60,7 +62,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
```

1. The name of the model running on the remote server
1 change: 1 addition & 0 deletions docs/api/settings.md
@@ -5,3 +5,4 @@
inherited_members: true
members:
- ModelSettings
- UsageLimits
4 changes: 3 additions & 1 deletion docs/results.md
@@ -19,7 +19,9 @@ result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None)
"""
```

_(This example is complete, it can be run "as is")_
14 changes: 12 additions & 2 deletions pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,8 +1,18 @@
from importlib.metadata import version

from .agent import Agent
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
from .tools import RunContext, Tool

__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
__all__ = (
'Agent',
'RunContext',
'Tool',
'AgentRunError',
'ModelRetry',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'UserError',
'__version__',
)
__version__ = version('pydantic_ai_slim')
33 changes: 25 additions & 8 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -22,7 +22,7 @@
result,
)
from .result import ResultData
from .settings import ModelSettings, merge_model_settings
from .settings import ModelSettings, UsageLimits, merge_model_settings
from .tools import (
AgentDeps,
RunContext,
@@ -191,6 +191,7 @@ async def run(
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt in async mode.
@@ -211,8 +212,9 @@ async def run(
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
@@ -237,12 +239,14 @@ async def run(
for tool in self._function_tools.values():
tool.current_retry = 0

usage = result.Usage()

usage = result.Usage(requests=0)
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
usage_limits.check_before_request(usage)

run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
agent_model = await self._prepare_model(model_used, deps, messages)
@@ -254,6 +258,8 @@

messages.append(model_response)
usage += request_usage
usage.requests += 1
usage_limits.check_tokens(request_usage)

with _logfire.span('handle model response', run_step=run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
@@ -284,6 +290,7 @@ def run_sync(
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.
@@ -308,8 +315,9 @@ async def main():
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
@@ -322,8 +330,9 @@ async def main():
message_history=message_history,
model=model,
deps=deps,
infer_name=False,
model_settings=model_settings,
usage_limits=usage_limits,
infer_name=False,
)
)

@@ -336,6 +345,7 @@ async def run_stream(
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
infer_name: bool = True,
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +367,9 @@ async def main():
message_history: History of the conversation so far.
model: Optional model to use for this run, required if `model` was not set when creating the agent.
deps: Optional dependencies to use for this run.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
Returns:
The result of the run.
@@ -387,16 +398,19 @@ async def main():

usage = result.Usage()
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
run_step += 1
usage_limits.check_before_request(usage)

with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
agent_model = await self._prepare_model(model_used, deps, messages)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
@@ -435,6 +449,7 @@ async def on_complete():
messages,
new_message_index,
usage,
usage_limits,
result_stream,
self._result_schema,
deps,
@@ -456,7 +471,9 @@ async def on_complete():
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
handle_span.message = f'handle model response -> {tool_responses_str}'
# the model_response should have been fully streamed by now, we can add its usage
usage += model_response.usage()
model_response_usage = model_response.usage()
usage += model_response_usage
usage_limits.check_tokens(usage)

@contextmanager
def override(
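
The run loop above calls `check_before_request` before each model request and `check_tokens` once usage has been accumulated. A rough sketch of those semantics (a hypothetical simplification for illustration, not the actual `UsageLimits` implementation; defaults and exact messages are assumptions):

```py
from dataclasses import dataclass

from pydantic_ai.exceptions import UsageLimitExceeded


@dataclass
class UsageLimitsSketch:
    # Defaults are assumptions for illustration; check the real UsageLimits for actual values.
    request_limit: int | None = 50
    response_tokens_limit: int | None = None

    def check_before_request(self, usage) -> None:
        # Raise before sending if one more request would exceed the request limit.
        if self.request_limit is not None and usage.requests + 1 > self.request_limit:
            raise UsageLimitExceeded(f'The next request would exceed the request_limit of {self.request_limit}')

    def check_tokens(self, usage) -> None:
        # Raise if accumulated response tokens exceed the configured ceiling.
        tokens = usage.response_tokens or 0
        if self.response_tokens_limit is not None and tokens > self.response_tokens_limit:
            raise UsageLimitExceeded(
                f'Exceeded the response_tokens_limit of {self.response_tokens_limit} '
                f'(response_tokens={tokens})'
            )
```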
22 changes: 20 additions & 2 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -2,7 +2,7 @@

import json

__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'


class ModelRetry(Exception):
@@ -30,7 +30,25 @@ def __init__(self, message: str):
super().__init__(message)


class UnexpectedModelBehavior(RuntimeError):
class AgentRunError(RuntimeError):
"""Base class for errors occurring during an agent run."""

message: str
"""The error message."""

def __init__(self, message: str):
self.message = message
super().__init__(message)

def __str__(self) -> str:
return self.message


class UsageLimitExceeded(AgentRunError):
"""Error raised when a Model's usage exceeds the specified limits."""


class UnexpectedModelBehavior(AgentRunError):
"""Error caused by unexpected Model behavior, e.g. an unexpected response code."""

message: str
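
With `AgentRunError` as the shared base class, callers can handle any agent-run failure in one place. A minimal sketch (model name and limits are illustrative):

```py
from pydantic_ai import Agent, AgentRunError
from pydantic_ai.settings import UsageLimits

agent = Agent('claude-3-5-sonnet-latest')

try:
    result = agent.run_sync(
        'What is the capital of Italy?',
        usage_limits=UsageLimits(request_limit=1, response_tokens_limit=10),
    )
except AgentRunError as e:  # catches UsageLimitExceeded and UnexpectedModelBehavior alike
    print(e)
```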
10 changes: 5 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -282,11 +282,11 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]
MessageParam(
role='user',
content=[
ToolUseBlockParam(
id=_guard_tool_call_id(t=part, model_source='Anthropic'),
input=part.model_response(),
name=part.tool_name,
type='tool_use',
ToolResultBlockParam(
tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
type='tool_result',
content=part.model_response(),
is_error=True,
),
],
)
14 changes: 7 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -237,7 +237,7 @@ def get(self, *, final: bool = False) -> ModelResponse:
return ModelResponse(calls, timestamp=self._timestamp)

def usage(self) -> result.Usage:
return result.Usage()
return _estimate_usage([self.get()])

def timestamp(self) -> datetime:
return self._timestamp
@@ -255,24 +255,24 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, (SystemPromptPart, UserPromptPart)):
request_tokens += _string_usage(part.content)
request_tokens += _estimate_string_usage(part.content)
elif isinstance(part, ToolReturnPart):
request_tokens += _string_usage(part.model_response_str())
request_tokens += _estimate_string_usage(part.model_response_str())
elif isinstance(part, RetryPromptPart):
request_tokens += _string_usage(part.model_response())
request_tokens += _estimate_string_usage(part.model_response())
else:
assert_never(part)
elif isinstance(message, ModelResponse):
for part in message.parts:
if isinstance(part, TextPart):
response_tokens += _string_usage(part.content)
response_tokens += _estimate_string_usage(part.content)
elif isinstance(part, ToolCallPart):
call = part
if isinstance(call.args, ArgsJson):
args_str = call.args.args_json
else:
args_str = pydantic_core.to_json(call.args.args_dict).decode()
response_tokens += 1 + _string_usage(args_str)
response_tokens += 1 + _estimate_string_usage(args_str)
else:
assert_never(part)
else:
@@ -282,5 +282,5 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
)


def _string_usage(content: str) -> int:
def _estimate_string_usage(content: str) -> int:
return len(re.split(r'[\s",.:]+', content))
