From 4537c07117a1ebbf8ef9eb3af775f76c00bd2b34 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 31 Dec 2024 17:20:19 +0000 Subject: [PATCH] Multi-agent application documentation (#541) Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- docs/agents.md | 6 +- docs/api/usage.md | 3 + docs/dependencies.md | 40 +- docs/examples/flight-booking.md | 41 +++ docs/extra/tweaks.css | 4 + docs/multi-agent-applications.md | 344 ++++++++++++++++++ .../pydantic_ai_examples/flight_booking.py | 242 ++++++++++++ mkdocs.yml | 3 + pydantic_ai_slim/pydantic_ai/agent.py | 28 +- pydantic_ai_slim/pydantic_ai/result.py | 161 ++++---- pydantic_ai_slim/pydantic_ai/settings.py | 62 +--- pydantic_ai_slim/pydantic_ai/usage.py | 114 ++++++ tests/test_agent.py | 42 ++- tests/test_examples.py | 44 ++- tests/test_usage_limits.py | 3 +- 15 files changed, 945 insertions(+), 192 deletions(-) create mode 100644 docs/api/usage.md create mode 100644 docs/examples/flight-booking.md create mode 100644 docs/multi-agent-applications.md create mode 100644 examples/pydantic_ai_examples/flight_booking.py create mode 100644 pydantic_ai_slim/pydantic_ai/usage.py diff --git a/docs/agents.md b/docs/agents.md index aa022288..0009564d 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -108,7 +108,7 @@ You can also pass messages from previous runs to continue a conversation or prov #### Usage Limits -PydanticAI offers a [`settings.UsageLimits`][pydantic_ai.settings.UsageLimits] structure to help you limit your +PydanticAI offers a [`UsageLimits`][pydantic_ai.usage.UsageLimits] structure to help you limit your usage (tokens and/or requests) on model runs. You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions. @@ -118,7 +118,7 @@ Consider the following example, where we limit the number of response tokens: ```py from pydantic_ai import Agent from pydantic_ai.exceptions import UsageLimitExceeded -from pydantic_ai.settings import UsageLimits +from pydantic_ai.usage import UsageLimits agent = Agent('claude-3-5-sonnet-latest') @@ -150,7 +150,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelRetry from pydantic_ai.exceptions import UsageLimitExceeded -from pydantic_ai.settings import UsageLimits +from pydantic_ai.usage import UsageLimits class NeverResultType(TypedDict): diff --git a/docs/api/usage.md b/docs/api/usage.md new file mode 100644 index 00000000..71e16208 --- /dev/null +++ b/docs/api/usage.md @@ -0,0 +1,3 @@ +# `pydantic_ai.usage` + +::: pydantic_ai.usage diff --git a/docs/dependencies.md b/docs/dependencies.md index ad6cc535..be533614 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -275,6 +275,8 @@ async def application_code(prompt: str) -> str: # (3)! 3. Application code that calls the agent, in a real application this might be an API endpoint. 4. Call the agent from within the application code, in a real application this call might be deep within a call stack. Note `app_deps` here will NOT be used when deps are overridden. +_(This example is complete, it can be run "as is")_ + ```python {title="test_joke_app.py" hl_lines="10-12" call_name="test_application_code"} from joke_app import MyDeps, application_code, joke_agent @@ -296,44 +298,6 @@ async def test_application_code(): 3. Override the dependencies of the agent for the duration of the `with` block, `test_deps` will be used when the agent is run. 4. Now we can safely call our application code, the agent will use the overridden dependencies. -## Agents as dependencies of other Agents - -Since dependencies can be any python type, and agents are just python objects, agents can be dependencies of other agents. - -```python {title="agents_as_dependencies.py"} -from dataclasses import dataclass - -from pydantic_ai import Agent, RunContext - - -@dataclass -class MyDeps: - factory_agent: Agent[None, list[str]] - - -joke_agent = Agent( - 'openai:gpt-4o', - deps_type=MyDeps, - system_prompt=( - 'Use the "joke_factory" to generate some jokes, then choose the best. ' - 'You must return just a single joke.' - ), -) - -factory_agent = Agent('gemini-1.5-pro', result_type=list[str]) - - -@joke_agent.tool -async def joke_factory(ctx: RunContext[MyDeps], count: int) -> str: - r = await ctx.deps.factory_agent.run(f'Please generate {count} jokes.') - return '\n'.join(r.data) - - -result = joke_agent.run_sync('Tell me a joke.', deps=MyDeps(factory_agent)) -print(result.data) -#> Did you hear about the toothpaste scandal? They called it Colgate. -``` - ## Examples The following examples demonstrate how to use dependencies in PydanticAI: diff --git a/docs/examples/flight-booking.md b/docs/examples/flight-booking.md new file mode 100644 index 00000000..7e6f1921 --- /dev/null +++ b/docs/examples/flight-booking.md @@ -0,0 +1,41 @@ +Example of a multi-agent flow where one agent delegates work to another, then hands off control to a third agent. + +Demonstrates: + +* [agent delegation](../multi-agent-applications.md#agent-delegation) +* [programmatic agent hand-off](../multi-agent-applications.md#programmatic-agent-hand-off) +* [usage limits](../agents.md#usage-limits) + +In this scenario, a group of agents work together to find the best flight for a user. + +The control flow for this example can be summarised as follows: + +```mermaid +graph TD + START --> search_agent("search agent") + search_agent --> extraction_agent("extraction agent") + extraction_agent --> search_agent + search_agent --> human_confirm("human confirm") + human_confirm --> search_agent + search_agent --> FAILED + human_confirm --> find_seat_function("find seat function") + find_seat_function --> human_seat_choice("human seat choice") + human_seat_choice --> find_seat_agent("find seat agent") + find_seat_agent --> find_seat_function + find_seat_function --> buy_flights("buy flights") + buy_flights --> SUCCESS +``` + +## Running the Example + +With [dependencies installed and environment variables set](./index.md#usage), run: + +```bash +python/uv-run -m pydantic_ai_examples.flight_booking +``` + +## Example Code + +```python {title="flight_booking.py"} +#! examples/pydantic_ai_examples/flight_booking.py +``` diff --git a/docs/extra/tweaks.css b/docs/extra/tweaks.css index 2e3206f8..c7d69224 100644 --- a/docs/extra/tweaks.css +++ b/docs/extra/tweaks.css @@ -55,3 +55,7 @@ img.index-header { min-height: 120px; margin-bottom: 10px; } + +.mermaid { + text-align: center; +} diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md new file mode 100644 index 00000000..ae66ec4e --- /dev/null +++ b/docs/multi-agent-applications.md @@ -0,0 +1,344 @@ +from pydantic_ai_examples.sql_gen import system_prompt + +# Multi-agent Applications + +There are roughly four levels of complexity when building applications with PydanticAI: + +1. Single agent workflows — what most of the `pydantic_ai` documentation covers +2. [Agent delegation](#agent-delegation) — agents using another agent via tools +3. [Programmatic agent hand-off](#programmatic-agent-hand-off) — one agent runs, then application code calls another agent +4. [Graph based control flow](#pydanticai-graphs) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents + +Of course, you can combine multiple strategies in a single application. + +## Agent delegation + +"Agent delegation" refers to the scenario where an agent delegates work to another agent, then takes back control when the delegate agent (the agent called from within a tool) finishes. + +Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](dependencies.md). + +You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.Agent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. + +!!! note "Multiple models" + Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.result.RunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. + +```python {title="agent_delegation_simple.py"} +from pydantic_ai import Agent, RunContext +from pydantic_ai.usage import UsageLimits + +joke_selection_agent = Agent( # (1)! + 'openai:gpt-4o', + system_prompt=( + 'Use the `joke_factory` to generate some jokes, then choose the best. ' + 'You must return just a single joke.' + ), +) +joke_generation_agent = Agent('gemini-1.5-flash', result_type=list[str]) # (2)! + + +@joke_selection_agent.tool +async def joke_factory(ctx: RunContext[None], count: int) -> list[str]: + r = await joke_generation_agent.run( # (3)! + f'Please generate {count} jokes.', + usage=ctx.usage, # (4)! + ) + return r.data # (5)! + + +result = joke_selection_agent.run_sync( + 'Tell me a joke.', + usage_limits=UsageLimits(request_limit=5, total_tokens_limit=300), +) +print(result.data) +#> Did you hear about the toothpaste scandal? They called it Colgate. +print(result.usage()) +""" +Usage( + requests=3, request_tokens=204, response_tokens=24, total_tokens=228, details=None +) +""" +``` + +1. The "parent" or controlling agent. +2. The "delegate" agent, which is called from within a tool of the parent agent. +3. Call the delegate agent from within a tool of the parent agent. +4. Pass the usage from the parent agent to the delegate agent so the final [`result.usage()`][pydantic_ai.result.RunResult.usage] includes the usage from both agents. +5. Since the function returns `#!python list[str]`, and the `result_type` of `joke_generation_agent` is also `#!python list[str]`, we can simply return `#!python r.data` from the tool. + +_(This example is complete, it can be run "as is")_ + +The control flow for this example is pretty simple and can be summarised as follows: + +```mermaid +graph TD + START --> joke_selection_agent + joke_selection_agent --> joke_factory["joke_factory (tool)"] + joke_factory --> joke_generation_agent + joke_generation_agent --> joke_factory + joke_factory --> joke_selection_agent + joke_selection_agent --> END +``` + +### Agent delegation and dependencies + +Generally the delegate agent needs to either have the same [dependencies](dependencies.md) as the calling agent, or dependencies which are a subset of the calling agent's dependencies. + +!!! info "Initializing dependencies" + We say "generally" above since there's nothing to stop you initializing dependencies within a tool call and therefore using interdependencies in a delegate agent that are not available on the parent, this should often be avoided since it can be significantly slower than reusing connections etc. from the parent agent. + +```python {title="agent_delegation_deps.py"} +from dataclasses import dataclass + +import httpx + +from pydantic_ai import Agent, RunContext + + +@dataclass +class ClientAndKey: # (1)! + http_client: httpx.AsyncClient + api_key: str + + +joke_selection_agent = Agent( + 'openai:gpt-4o', + deps_type=ClientAndKey, # (2)! + system_prompt=( + 'Use the `joke_factory` tool to generate some jokes on the given subject, ' + 'then choose the best. You must return just a single joke.' + ), +) +joke_generation_agent = Agent( + 'gemini-1.5-flash', + deps_type=ClientAndKey, # (4)! + result_type=list[str], + system_prompt=( + 'Use the "get_jokes" tool to get some jokes on the given subject, ' + 'then extract each joke into a list.' + ), +) + + +@joke_selection_agent.tool +async def joke_factory(ctx: RunContext[ClientAndKey], count: int) -> list[str]: + r = await joke_generation_agent.run( + f'Please generate {count} jokes.', + deps=ctx.deps, # (3)! + usage=ctx.usage, + ) + return r.data + + +@joke_generation_agent.tool # (5)! +async def get_jokes(ctx: RunContext[ClientAndKey], count: int) -> str: + response = await ctx.deps.http_client.get( + 'https://example.com', + params={'count': count}, + headers={'Authorization': f'Bearer {ctx.deps.api_key}'}, + ) + response.raise_for_status() + return response.text + + +async def main(): + async with httpx.AsyncClient() as client: + deps = ClientAndKey(client, 'foobar') + result = await joke_selection_agent.run('Tell me a joke.', deps=deps) + print(result.data) + #> Did you hear about the toothpaste scandal? They called it Colgate. + print(result.usage()) # (6)! + """ + Usage( + requests=4, + request_tokens=310, + response_tokens=32, + total_tokens=342, + details=None, + ) + """ +``` + +1. Define a dataclass to hold the client and API key dependencies. +2. Set the `deps_type` of the calling agent — `joke_selection_agent` here. +3. Pass the dependencies to the delegate agent's run method within the tool call. +4. Also set the `deps_type` of the delegate agent — `joke_generation_agent` here. +5. Define a tool on the delegate agent that uses the dependencies to make an HTTP request. +6. Usage now includes 4 requests — 2 from the calling agent and 2 from the delegate agent. + +_(This example is complete, it can be run "as is")_ + +This example shows how even a fairly simple agent delegation can lead to a complex control flow: + +```mermaid +graph TD + START --> joke_selection_agent + joke_selection_agent --> joke_factory["joke_factory (tool)"] + joke_factory --> joke_generation_agent + joke_generation_agent --> get_jokes["get_jokes (tool)"] + get_jokes --> http_request["HTTP request"] + http_request --> get_jokes + get_jokes --> joke_generation_agent + joke_generation_agent --> joke_factory + joke_factory --> joke_selection_agent + joke_selection_agent --> END +``` + +## Programmatic agent hand-off + +"Programmatic agent hand-off" refers to the scenario where multiple agents are called in succession, with application code and/or a human in the loop responsible for deciding which agent to call next. + +Here agents don't need to use the same deps. + +Here we show two agents used in succession, the first to find a flight and the second to extract the user's seat preference. + +```python {title="programmatic_handoff.py"} +from typing import Literal, Union + +from pydantic import BaseModel, Field +from rich.prompt import Prompt + +from pydantic_ai import Agent, RunContext +from pydantic_ai.messages import ModelMessage +from pydantic_ai.usage import Usage, UsageLimits + + +class FlightDetails(BaseModel): + flight_number: str + + +class Failed(BaseModel): + """Unable to find a satisfactory choice.""" + + +flight_search_agent = Agent[None, Union[FlightDetails, Failed]]( # (1)! + 'openai:gpt-4o', + result_type=Union[FlightDetails, Failed], # type: ignore + system_prompt=( + 'Use the "flight_search" tool to find a flight ' + 'from the given origin to the given destination.' + ), +) + + +@flight_search_agent.tool # (2)! +async def flight_search( + ctx: RunContext[None], origin: str, destination: str +) -> Union[FlightDetails, None]: + # in reality, this would call a flight search API or + # use a browser to scrape a flight search website + return FlightDetails(flight_number='AK456') + + +usage_limits = UsageLimits(request_limit=15) # (3)! + + +async def find_flight(usage: Usage) -> Union[FlightDetails, None]: # (4)! + message_history: Union[list[ModelMessage], None] = None + for _ in range(3): + prompt = Prompt.ask( + 'Where would you like to fly from and to?', + ) + result = await flight_search_agent.run( + prompt, + message_history=message_history, + usage=usage, + usage_limits=usage_limits, + ) + if isinstance(result.data, FlightDetails): + return result.data + else: + message_history = result.all_messages( + result_tool_return_content='Please try again.' + ) + + +class SeatPreference(BaseModel): + row: int = Field(ge=1, le=30) + seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] + + +# This agent is responsible for extracting the user's seat selection +seat_preference_agent = Agent[None, Union[SeatPreference, Failed]]( # (5)! + 'openai:gpt-4o', + result_type=Union[SeatPreference, Failed], # type: ignore + system_prompt=( + "Extract the user's seat preference. " + 'Seats A and F are window seats. ' + 'Row 1 is the front row and has extra leg room. ' + 'Rows 14, and 20 also have extra leg room. ' + ), +) + + +async def find_seat(usage: Usage) -> SeatPreference: # (6)! + message_history: Union[list[ModelMessage], None] = None + while True: + answer = Prompt.ask('What seat would you like?') + + result = await seat_preference_agent.run( + answer, + message_history=message_history, + usage=usage, + usage_limits=usage_limits, + ) + if isinstance(result.data, SeatPreference): + return result.data + else: + print('Could not understand seat preference. Please try again.') + message_history = result.all_messages() + + +async def main(): # (7)! + usage: Usage = Usage() + + opt_flight_details = await find_flight(usage) + if opt_flight_details is not None: + print(f'Flight found: {opt_flight_details.flight_number}') + #> Flight found: AK456 + seat_preference = await find_seat(usage) + print(f'Seat preference: {seat_preference}') + #> Seat preference: row=1 seat='A' +``` + +1. Define the first agent, which finds a flight. We use an explicit type annotation until [PEP-747](https://peps.python.org/pep-0747/) lands, see [structured results](results.md#structured-result-validation). We use a union as the result type so the model can communicate if it's unable to find a satisfactory choice; internally, each member of the union will be registered as a separate tool. +2. Define a tool on the agent to find a flight. In this simple case we could dispense with the tool and just define the agent to return structured data, then search for a flight, but in more complex scenarios the tool would be necessary. +3. Define usage limits for the entire app. +4. Define a function to find a flight, which asks the user for their preferences and then calls the agent to find a flight. +5. As with `flight_search_agent` above, we use an explicit type annotation to define the agent. +6. Define a function to find the user's seat preference, which asks the user for their seat preference and then calls the agent to extract the seat preference. +7. Now that we've put our logic for running each agent into separate functions, our main app becomes very simple. + +_(This example is complete, it can be run "as is")_ + +The control flow for this example can be summarised as follows: + +```mermaid +graph TB + START --> ask_user_flight["ask user for flight"] + + subgraph find_flight + flight_search_agent --> ask_user_flight + ask_user_flight --> flight_search_agent + end + + flight_search_agent --> ask_user_seat["ask user for seat"] + flight_search_agent --> END + + subgraph find_seat + seat_preference_agent --> ask_user_seat + ask_user_seat --> seat_preference_agent + end + + seat_preference_agent --> END +``` + +## PydanticAI Graphs + +!!! example "Work in progress" + This is a work in progress and not yet documented, see [#528](https://github.com/pydantic/pydantic-ai/issues/528) and [#539](https://github.com/pydantic/pydantic-ai/issues/539) + +## Examples + +The following examples demonstrate how to use dependencies in PydanticAI: + +- [Flight booking](examples/flight-booking.md) diff --git a/examples/pydantic_ai_examples/flight_booking.py b/examples/pydantic_ai_examples/flight_booking.py new file mode 100644 index 00000000..209e2adf --- /dev/null +++ b/examples/pydantic_ai_examples/flight_booking.py @@ -0,0 +1,242 @@ +"""Example of a multi-agent flow where one agent delegates work to another. + +In this scenario, a group of agents work together to find flights for a user. +""" + +import datetime +from dataclasses import dataclass +from typing import Literal + +import logfire +from pydantic import BaseModel, Field +from rich.prompt import Prompt + +from pydantic_ai import Agent, ModelRetry, RunContext +from pydantic_ai.messages import ModelMessage +from pydantic_ai.usage import Usage, UsageLimits + +# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured +logfire.configure(send_to_logfire='if-token-present') + + +class FlightDetails(BaseModel): + """Details of the most suitable flight.""" + + flight_number: str + price: int + origin: str = Field(description='Three-letter airport code') + destination: str = Field(description='Three-letter airport code') + date: datetime.date + + +class NoFlightFound(BaseModel): + """When no valid flight is found.""" + + +@dataclass +class Deps: + web_page_text: str + req_origin: str + req_destination: str + req_date: datetime.date + + +# This agent is responsible for controlling the flow of the conversation. +search_agent = Agent[Deps, FlightDetails | NoFlightFound]( + 'openai:gpt-4o', + result_type=FlightDetails | NoFlightFound, # type: ignore + retries=4, + system_prompt=( + 'Your job is to find the cheapest flight for the user on the given date. ' + ), +) + + +# This agent is responsible for extracting flight details from web page text. +extraction_agent = Agent( + 'openai:gpt-4o', + result_type=list[FlightDetails], + system_prompt='Extract all the flight details from the given text.', +) + + +@search_agent.tool +async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]: + """Get details of all flights.""" + # we pass the usage to the search agent so requests within this agent are counted + result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage) + logfire.info('found {flight_count} flights', flight_count=len(result.data)) + return result.data + + +@search_agent.result_validator +async def validate_result( + ctx: RunContext[Deps], result: FlightDetails | NoFlightFound +) -> FlightDetails | NoFlightFound: + """Procedural validation that the flight meets the constraints.""" + if isinstance(result, NoFlightFound): + return result + + errors: list[str] = [] + if result.origin != ctx.deps.req_origin: + errors.append( + f'Flight should have origin {ctx.deps.req_origin}, not {result.origin}' + ) + if result.destination != ctx.deps.req_destination: + errors.append( + f'Flight should have destination {ctx.deps.req_destination}, not {result.destination}' + ) + if result.date != ctx.deps.req_date: + errors.append(f'Flight should be on {ctx.deps.req_date}, not {result.date}') + + if errors: + raise ModelRetry('\n'.join(errors)) + else: + return result + + +class SeatPreference(BaseModel): + row: int = Field(ge=1, le=30) + seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] + + +class Failed(BaseModel): + """Unable to extract a seat selection.""" + + +# This agent is responsible for extracting the user's seat selection +seat_preference_agent = Agent[ + None, SeatPreference | Failed +]( + 'openai:gpt-4o', + result_type=SeatPreference | Failed, # type: ignore + system_prompt=( + "Extract the user's seat preference. " + 'Seats A and F are window seats. ' + 'Row 1 is the front row and has extra leg room. ' + 'Rows 14, and 20 also have extra leg room. ' + ), +) + + +# in reality this would be downloaded from a booking site, +# potentially using another agent to navigate the site +flights_web_page = """ +1. Flight SFO-AK123 +- Price: $350 +- Origin: San Francisco International Airport (SFO) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 10, 2025 + +2. Flight SFO-AK456 +- Price: $370 +- Origin: San Francisco International Airport (SFO) +- Destination: Fairbanks International Airport (FAI) +- Date: January 10, 2025 + +3. Flight SFO-AK789 +- Price: $400 +- Origin: San Francisco International Airport (SFO) +- Destination: Juneau International Airport (JNU) +- Date: January 20, 2025 + +4. Flight NYC-LA101 +- Price: $250 +- Origin: San Francisco International Airport (SFO) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 10, 2025 + +5. Flight CHI-MIA202 +- Price: $200 +- Origin: Chicago O'Hare International Airport (ORD) +- Destination: Miami International Airport (MIA) +- Date: January 12, 2025 + +6. Flight BOS-SEA303 +- Price: $120 +- Origin: Boston Logan International Airport (BOS) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 12, 2025 + +7. Flight DFW-DEN404 +- Price: $150 +- Origin: Dallas/Fort Worth International Airport (DFW) +- Destination: Denver International Airport (DEN) +- Date: January 10, 2025 + +8. Flight ATL-HOU505 +- Price: $180 +- Origin: Hartsfield-Jackson Atlanta International Airport (ATL) +- Destination: George Bush Intercontinental Airport (IAH) +- Date: January 10, 2025 +""" + +# restrict how many requests this app can make to the LLM +usage_limits = UsageLimits(request_limit=15) + + +async def main(): + deps = Deps( + web_page_text=flights_web_page, + req_origin='SFO', + req_destination='ANC', + req_date=datetime.date(2025, 1, 10), + ) + message_history: list[ModelMessage] | None = None + usage: Usage = Usage() + # run the agent until a satisfactory flight is found + while True: + result = await search_agent.run( + f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}', + deps=deps, + usage=usage, + message_history=message_history, + usage_limits=usage_limits, + ) + if isinstance(result.data, NoFlightFound): + print('No flight found') + break + else: + flight = result.data + print(f'Flight found: {flight}') + answer = Prompt.ask( + 'Do you want to buy this flight, or keep searching? (buy/*search)', + choices=['buy', 'search', ''], + show_choices=False, + ) + if answer == 'buy': + seat = await find_seat(usage) + await buy_tickets(flight, seat) + break + else: + message_history = result.all_messages( + result_tool_return_content='Please suggest another flight' + ) + + +async def find_seat(usage: Usage) -> SeatPreference: + message_history: list[ModelMessage] | None = None + while True: + answer = Prompt.ask('What seat would you like?') + + result = await seat_preference_agent.run( + answer, + message_history=message_history, + usage=usage, + usage_limits=usage_limits, + ) + if isinstance(result.data, SeatPreference): + return result.data + else: + print('Could not understand seat preference. Please try again.') + message_history = result.all_messages() + + +async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference): + print(f'Purchasing flight {flight_details=!r} {seat=!r}...') + + +if __name__ == '__main__': + import asyncio + + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 65ffd155..e60e3a97 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -23,12 +23,14 @@ nav: - message-history.md - testing-evals.md - logfire.md + - multi-agent-applications.md - Examples: - examples/index.md - examples/pydantic-model.md - examples/weather-agent.md - examples/bank-support.md - examples/sql-gen.md + - examples/flight-booking.md - examples/rag.md - examples/stream-markdown.md - examples/stream-whales.md @@ -40,6 +42,7 @@ nav: - api/messages.md - api/exceptions.md - api/settings.md + - api/usage.md - api/models/base.md - api/models/openai.md - api/models/anthropic.md diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 1e5b60b8..b4af5db7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -20,9 +20,10 @@ messages as _messages, models, result, + usage as _usage, ) from .result import ResultData -from .settings import ModelSettings, UsageLimits, merge_model_settings +from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDeps, RunContext, @@ -192,8 +193,8 @@ async def run( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: result.Usage | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -236,7 +237,7 @@ async def run( model_name=model_used.name(), agent_name=self.name or 'agent', ) as run_span: - run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt) + run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt) messages = await self._prepare_messages(user_prompt, message_history, run_context) run_context.messages = messages @@ -244,7 +245,7 @@ async def run( tool.current_retry = 0 model_settings = merge_model_settings(self.model_settings, model_settings) - usage_limits = usage_limits or UsageLimits() + usage_limits = usage_limits or _usage.UsageLimits() while True: usage_limits.check_before_request(run_context.usage) @@ -272,11 +273,14 @@ async def run( # Check if we got a final result if final_result is not None: result_data = final_result.data + result_tool_name = final_result.tool_name run_span.set_attribute('all_messages', messages) run_span.set_attribute('usage', run_context.usage) handle_span.set_attribute('result', result_data) handle_span.message = 'handle model response -> final result' - return result.RunResult(messages, new_message_index, result_data, run_context.usage) + return result.RunResult( + messages, new_message_index, result_data, result_tool_name, run_context.usage + ) else: # continue the conversation handle_span.set_attribute('tool_responses', tool_responses) @@ -291,8 +295,8 @@ def run_sync( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: result.Usage | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. @@ -349,8 +353,8 @@ async def run_stream( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: result.Usage | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, infer_name: bool = True, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -396,7 +400,7 @@ async def main(): model_name=model_used.name(), agent_name=self.name or 'agent', ) as run_span: - run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt) + run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt) messages = await self._prepare_messages(user_prompt, message_history, run_context) run_context.messages = messages @@ -404,7 +408,7 @@ async def main(): tool.current_retry = 0 model_settings = merge_model_settings(self.model_settings, model_settings) - usage_limits = usage_limits or UsageLimits() + usage_limits = usage_limits or _usage.UsageLimits() while True: run_context.run_step += 1 diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 14606ea7..74cb9cc0 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Awaitable, Callable -from copy import copy +from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime from typing import Generic, Union, cast @@ -11,16 +11,10 @@ from typing_extensions import TypeVar from . import _result, _utils, exceptions, messages as _messages, models -from .settings import UsageLimits from .tools import AgentDeps, RunContext +from .usage import Usage, UsageLimits -__all__ = ( - 'ResultData', - 'ResultValidatorFunc', - 'Usage', - 'RunResult', - 'StreamedRunResult', -) +__all__ = 'ResultData', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult' ResultData = TypeVar('ResultData', default=str) @@ -44,55 +38,6 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') -@dataclass -class Usage: - """LLM usage associated with a request or run. - - Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. - - You'll need to look up the documentation of the model you're using to convert usage to monetary costs. - """ - - requests: int = 0 - """Number of requests made to the LLM API.""" - request_tokens: int | None = None - """Tokens used in processing requests.""" - response_tokens: int | None = None - """Tokens used in generating responses.""" - total_tokens: int | None = None - """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" - details: dict[str, int] | None = None - """Any extra details returned by the model.""" - - def incr(self, incr_usage: Usage, *, requests: int = 0) -> None: - """Increment the usage in place. - - Args: - incr_usage: The usage to increment by. - requests: The number of requests to increment by in addition to `incr_usage.requests`. - """ - self.requests += requests - for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': - self_value = getattr(self, f) - other_value = getattr(incr_usage, f) - if self_value is not None or other_value is not None: - setattr(self, f, (self_value or 0) + (other_value or 0)) - - if incr_usage.details: - self.details = self.details or {} - for key, value in incr_usage.details.items(): - self.details[key] = self.details.get(key, 0) + value - - def __add__(self, other: Usage) -> Usage: - """Add two Usages together. - - This is provided so it's trivial to sum usage information from multiple requests and runs. - """ - new_usage = copy(self) - new_usage.incr(other) - return new_usage - - @dataclass class _BaseRunResult(ABC, Generic[ResultData]): """Base type for results. @@ -103,25 +48,70 @@ class _BaseRunResult(ABC, Generic[ResultData]): _all_messages: list[_messages.ModelMessage] _new_message_index: int - def all_messages(self) -> list[_messages.ModelMessage]: - """Return the history of _messages.""" + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return the history of _messages. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of messages. + """ # this is a method to be consistent with the other methods + if result_tool_return_content is not None: + raise NotImplementedError('Setting result tool return content is not supported for this result type.') return self._all_messages - def all_messages_json(self) -> bytes: - """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes.""" - return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages()) + def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the messages. + """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.all_messages(result_tool_return_content=result_tool_return_content) + ) - def new_messages(self) -> list[_messages.ModelMessage]: + def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return new messages associated with this run. - System prompts and any messages from older runs are excluded. + Messages from older runs are excluded. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of new messages. """ - return self.all_messages()[self._new_message_index :] + return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] + + def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes. - def new_messages_json(self) -> bytes: - """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes.""" - return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages()) + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the new messages. + """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.new_messages(result_tool_return_content=result_tool_return_content) + ) @abstractmethod def usage(self) -> Usage: @@ -134,12 +124,45 @@ class RunResult(_BaseRunResult[ResultData]): data: ResultData """Data from the final response in the run.""" + _result_tool_name: str | None _usage: Usage def usage(self) -> Usage: """Return the usage of the whole run.""" return self._usage + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return the history of _messages. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of messages. + """ + if result_tool_return_content is not None: + return self._set_result_tool_return(result_tool_return_content) + else: + return self._all_messages + + def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: + """Set return content for the result tool. + + Useful if you want to continue the conversation and want to set the response to the result tool call. + """ + if not self._result_tool_name: + raise ValueError('Cannot set result tool return content when the return type is `str`.') + messages = deepcopy(self._all_messages) + last_message = messages[-1] + for part in last_message.parts: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name: + part.content = return_content + return messages + raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.') + @dataclass class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]): diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py index 9fcadf5a..d6728948 100644 --- a/pydantic_ai_slim/pydantic_ai/settings.py +++ b/pydantic_ai_slim/pydantic_ai/settings.py @@ -1,15 +1,12 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING from httpx import Timeout from typing_extensions import TypedDict -from .exceptions import UsageLimitExceeded - if TYPE_CHECKING: - from .result import Usage + pass class ModelSettings(TypedDict, total=False): @@ -82,60 +79,3 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | return base | overrides else: return base or overrides - - -@dataclass -class UsageLimits: - """Limits on model usage. - - The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model. - Token counts are provided in responses from the model, and the token limits are checked after each response. - - Each of the limits can be set to `None` to disable that limit. - """ - - request_limit: int | None = 50 - """The maximum number of requests allowed to the model.""" - request_tokens_limit: int | None = None - """The maximum number of tokens allowed in requests to the model.""" - response_tokens_limit: int | None = None - """The maximum number of tokens allowed in responses from the model.""" - total_tokens_limit: int | None = None - """The maximum number of tokens allowed in requests and responses combined.""" - - def has_token_limits(self) -> bool: - """Returns `True` if this instance places any limits on token counts. - - If this returns `False`, the `check_tokens` method will never raise an error. - - This is useful because if we have token limits, we need to check them after receiving each streamed message. - If there are no limits, we can skip that processing in the streaming response iterator. - """ - return any( - limit is not None - for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) - ) - - def check_before_request(self, usage: Usage) -> None: - """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" - request_limit = self.request_limit - if request_limit is not None and usage.requests >= request_limit: - raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') - - def check_tokens(self, usage: Usage) -> None: - """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" - request_tokens = usage.request_tokens or 0 - if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: - raise UsageLimitExceeded( - f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' - ) - - response_tokens = usage.response_tokens or 0 - if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: - raise UsageLimitExceeded( - f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' - ) - - total_tokens = usage.total_tokens or 0 - if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: - raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py new file mode 100644 index 00000000..054be4e3 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -0,0 +1,114 @@ +from __future__ import annotations as _annotations + +from copy import copy +from dataclasses import dataclass + +from .exceptions import UsageLimitExceeded + +__all__ = 'Usage', 'UsageLimits' + + +@dataclass +class Usage: + """LLM usage associated with a request or run. + + Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. + + You'll need to look up the documentation of the model you're using to convert usage to monetary costs. + """ + + requests: int = 0 + """Number of requests made to the LLM API.""" + request_tokens: int | None = None + """Tokens used in processing requests.""" + response_tokens: int | None = None + """Tokens used in generating responses.""" + total_tokens: int | None = None + """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" + details: dict[str, int] | None = None + """Any extra details returned by the model.""" + + def incr(self, incr_usage: Usage, *, requests: int = 0) -> None: + """Increment the usage in place. + + Args: + incr_usage: The usage to increment by. + requests: The number of requests to increment by in addition to `incr_usage.requests`. + """ + self.requests += requests + for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': + self_value = getattr(self, f) + other_value = getattr(incr_usage, f) + if self_value is not None or other_value is not None: + setattr(self, f, (self_value or 0) + (other_value or 0)) + + if incr_usage.details: + self.details = self.details or {} + for key, value in incr_usage.details.items(): + self.details[key] = self.details.get(key, 0) + value + + def __add__(self, other: Usage) -> Usage: + """Add two Usages together. + + This is provided so it's trivial to sum usage information from multiple requests and runs. + """ + new_usage = copy(self) + new_usage.incr(other) + return new_usage + + +@dataclass +class UsageLimits: + """Limits on model usage. + + The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model. + Token counts are provided in responses from the model, and the token limits are checked after each response. + + Each of the limits can be set to `None` to disable that limit. + """ + + request_limit: int | None = 50 + """The maximum number of requests allowed to the model.""" + request_tokens_limit: int | None = None + """The maximum number of tokens allowed in requests to the model.""" + response_tokens_limit: int | None = None + """The maximum number of tokens allowed in responses from the model.""" + total_tokens_limit: int | None = None + """The maximum number of tokens allowed in requests and responses combined.""" + + def has_token_limits(self) -> bool: + """Returns `True` if this instance places any limits on token counts. + + If this returns `False`, the `check_tokens` method will never raise an error. + + This is useful because if we have token limits, we need to check them after receiving each streamed message. + If there are no limits, we can skip that processing in the streaming response iterator. + """ + return any( + limit is not None + for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) + ) + + def check_before_request(self, usage: Usage) -> None: + """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" + request_limit = self.request_limit + if request_limit is not None and usage.requests >= request_limit: + raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') + + def check_tokens(self, usage: Usage) -> None: + """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" + request_tokens = usage.request_tokens or 0 + if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: + raise UsageLimitExceeded( + f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' + ) + + response_tokens = usage.response_tokens or 0 + if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: + raise UsageLimitExceeded( + f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' + ) + + total_tokens = usage.total_tokens or 0 + if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: + raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/tests/test_agent.py b/tests/test_agent.py index 6b232087..277eeeb5 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,4 +1,5 @@ import json +import re import sys from datetime import timezone from typing import Any, Callable, Union @@ -227,7 +228,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo: ) -def test_plain_response(set_event_loop: None): +def test_plain_response_then_tuple(set_event_loop: None): call_index = 0 def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: @@ -271,6 +272,42 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) + assert result._result_tool_name == 'final_result' # pyright: ignore[reportPrivateUsage] + assert result.all_messages(result_tool_return_content='foobar')[-1] == snapshot( + ModelRequest( + parts=[ToolReturnPart(tool_name='final_result', content='foobar', timestamp=IsNow(tz=timezone.utc))] + ) + ) + assert result.all_messages()[-1] == snapshot( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) + ) + ] + ) + ) + + +def test_result_tool_return_content_str_return(set_event_loop: None): + agent = Agent('test') + + result = agent.run_sync('Hello') + assert result.data == 'success (no tool calls)' + + msg = re.escape('Cannot set result tool return content when the return type is `str`.') + with pytest.raises(ValueError, match=msg): + result.all_messages(result_tool_return_content='foobar') + + +def test_result_tool_return_content_no_tool(set_event_loop: None): + agent = Agent('test', result_type=int) + + result = agent.run_sync('Hello') + assert result.data == 0 + result._result_tool_name = 'wrong' # pyright: ignore[reportPrivateUsage] + with pytest.raises(LookupError, match=re.escape("No tool call found with tool name 'wrong'.")): + result.all_messages(result_tool_return_content='foobar') def test_response_tuple(set_event_loop: None): @@ -507,6 +544,7 @@ async def ret_a(x: str) -> str: ], _new_message_index=4, data='{"ret_a":"a-apple"}', + _result_tool_name=None, _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), ) ) @@ -549,6 +587,7 @@ async def ret_a(x: str) -> str: ], _new_message_index=4, data='{"ret_a":"a-apple"}', + _result_tool_name=None, _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), ) ) @@ -648,6 +687,7 @@ async def ret_a(x: str) -> str: ), ], _new_message_index=5, + _result_tool_name='final_result', _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None), ) ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 6f0e7912..9f296999 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -72,6 +72,7 @@ def test_docs_examples( mocker.patch('httpx.AsyncClient.get', side_effect=async_http_request) mocker.patch('httpx.AsyncClient.post', side_effect=async_http_request) mocker.patch('random.randint', return_value=4) + mocker.patch('rich.prompt.Prompt.ask', side_effect=rich_prompt_ask) env.set('OPENAI_API_KEY', 'testing') env.set('GEMINI_API_KEY', 'testing') @@ -145,6 +146,14 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: return http_request(url, **kwargs) +def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: + if prompt == 'Where would you like to fly from and to?': + return 'SFO to ANC' + else: + assert prompt == 'What seat would you like?', prompt + return 'window seat with leg room' + + text_responses: dict[str, str | ToolCallPart] = { 'What is the weather like in West London and in Wiltshire?': ( 'The weather in West London is raining, while in Wiltshire it is sunny.' @@ -218,21 +227,36 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: 'Rome is known for its rich history, stunning architecture, and delicious cuisine.' ), 'Begin infinite retry loop!': ToolCallPart(tool_name='infinite_retry_tool', args=ArgsDict({})), + 'Please generate 5 jokes.': ToolCallPart( + tool_name='final_result', + args=ArgsDict({'response': []}), + ), + 'SFO to ANC': ToolCallPart( + tool_name='flight_search', + args=ArgsDict({'origin': 'SFO', 'destination': 'ANC'}), + ), + 'window seat with leg room': ToolCallPart( + tool_name='final_result_SeatPreference', + args=ArgsDict({'row': 1, 'seat': 'A'}), + ), } -async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover +async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover # noqa: C901 m = messages[-1].parts[-1] if isinstance(m, UserPromptPart): - if response := text_responses.get(m.content): + if m.content == 'Tell me a joke.' and any(t.name == 'joke_factory' for t in info.function_tools): + return ModelResponse(parts=[ToolCallPart(tool_name='joke_factory', args=ArgsDict({'count': 5}))]) + elif m.content == 'Please generate 5 jokes.' and any(t.name == 'get_jokes' for t in info.function_tools): + return ModelResponse(parts=[ToolCallPart(tool_name='get_jokes', args=ArgsDict({'count': 5}))]) + elif re.fullmatch(r'sql prompt \d+', m.content): + return ModelResponse.from_text(content='SELECT 1') + elif response := text_responses.get(m.content): if isinstance(response, str): return ModelResponse.from_text(content=response) else: return ModelResponse(parts=[response]) - if re.fullmatch(r'sql prompt \d+', m.content): - return ModelResponse.from_text(content='SELECT 1') - elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel': win = m.content == 'winner' return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict({'response': win}))]) @@ -249,7 +273,7 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes elif isinstance(m, RetryPromptPart) and m.tool_name == 'infinite_retry_tool': return ModelResponse(parts=[ToolCallPart(tool_name='infinite_retry_tool', args=ArgsDict({}))]) elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_user_by_name': - args = { + args: dict[str, Any] = { 'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!', 'user_id': 123, } @@ -263,6 +287,14 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes 'risk': 1, } return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args))]) + elif isinstance(m, ToolReturnPart) and m.tool_name == 'joke_factory': + return ModelResponse.from_text(content='Did you hear about the toothpaste scandal? They called it Colgate.') + elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_jokes': + args = {'response': []} + return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args))]) + elif isinstance(m, ToolReturnPart) and m.tool_name == 'flight_search': + args = {'flight_number': m.content.flight_number} # type: ignore + return ModelResponse(parts=[ToolCallPart(tool_name='final_result_FlightDetails', args=ArgsDict(args))]) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 37b45142..cbbfec34 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -16,8 +16,7 @@ UserPromptPart, ) from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage -from pydantic_ai.settings import UsageLimits +from pydantic_ai.usage import Usage, UsageLimits from .conftest import IsNow