From e51dbee59c05e836cd9e4576641babbee73d1fb3 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 22 Oct 2024 04:14:29 +0100 Subject: [PATCH 1/6] improving examples code and docs --- Makefile | 9 +- examples/README.md | 67 ++++++++++++++ examples/parse_model.py | 15 --- examples/pydantic_model.py | 31 +++++++ examples/sql_gen.py | 20 +++- examples/weather.py | 138 ++++++++++++++++++++++------ pydantic_ai/agent.py | 13 ++- pyproject.toml | 5 + tests/models/test_model_function.py | 18 ++-- uv.lock | 31 +++++++ 10 files changed, 283 insertions(+), 64 deletions(-) create mode 100644 examples/README.md delete mode 100644 examples/parse_model.py create mode 100644 examples/pydantic_model.py diff --git a/Makefile b/Makefile index 8f9a4b7a..01b3f908 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,4 @@ .DEFAULT_GOAL := all -sources = pydantic_ai tests .PHONY: .uv # Check that uv is installed .uv: @@ -16,13 +15,13 @@ install: .uv .pre-commit .PHONY: format # Format the code format: - uv run ruff format $(sources) - uv run ruff check --fix --fix-only $(sources) + uv run ruff format + uv run ruff check --fix --fix-only .PHONY: lint # Lint the code lint: - uv run ruff format --check $(sources) - uv run ruff check $(sources) + uv run ruff format --check + uv run ruff check .PHONY: typecheck-pyright typecheck-pyright: diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..f8fcdee8 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,67 @@ +# Pydantic AI Examples + +Examples of how to use Pydantic AI and what it can do. + +## Usage + +To run the examples, run: + +```bash +uv run -m examples. +``` + +## Examples + +### `pydantic_model` + +Simple example of using Pydantic AI to construct a Pydantic model from a text input. + +```bash +uv run -m examples.pydantic_model +``` + +This examples uses `openai:gpt-4o` by default but it works well with other modesl, e.g. you can run it +with Gemini using: + +```bash +PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m examples.pydantic_model +``` + +or + +```bash +PYDANTIC_AI_MODEL=gemini-1.5-flash uv run -m examples.pydantic_model +``` + +```bash + +### `sql_gen` + +Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. + +```bash +uv run -m examples.sql_gen +``` + +This model uses `gemini-1.5-flash` by default since Gemini is good at single shot queries. + +### `weather` + +Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. + +In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities, +the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use +the `get_weather` tool to get the weather. + +To run this example properly, you'll need two extra API keys: +* A weather API key from [tomorrow.io](https://www.tomorrow.io/weather-api/) set via `WEATHER_API_KEY` +* A geocoding API key from [geocode.maps.co](https://geocode.maps.co/) set via `GEO_API_KEY` + +**(Note if either key is missing, the code will fall back to dummy data.)** + +```bash +uv run -m examples.weather +``` + +This example uses `openai:gpt-4o` by default. Gemini seems to be unable to handle the multiple tool +calls. diff --git a/examples/parse_model.py b/examples/parse_model.py deleted file mode 100644 index a389ce9d..00000000 --- a/examples/parse_model.py +++ /dev/null @@ -1,15 +0,0 @@ -from pydantic import BaseModel - -from pydantic_ai import Agent - - -class MyModel(BaseModel): - city: str - country: str - - -agent = Agent('openai:gpt-4o', result_type=MyModel, deps=None) - -if __name__ == '__main__': - result = agent.run_sync('The windy city in the US of A.') - print(result.response) diff --git a/examples/pydantic_model.py b/examples/pydantic_model.py new file mode 100644 index 00000000..0714e611 --- /dev/null +++ b/examples/pydantic_model.py @@ -0,0 +1,31 @@ +"""Simple example of using Pydantic AI to construct a Pydantic model from a text input. + +Run with: + + uv run -m examples.pydantic_model +""" + +import os +from typing import cast + +# if you don't want to use logfire, just comment out these lines +import logfire +from pydantic import BaseModel + +from pydantic_ai import Agent +from pydantic_ai.agent import KnownModelName + +logfire.configure() + + +class MyModel(BaseModel): + city: str + country: str + + +model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o')) +agent = Agent(model, result_type=MyModel, deps=None) + +if __name__ == '__main__': + result = agent.run_sync('The windy city in the US of A.') + print(result.response) diff --git a/examples/sql_gen.py b/examples/sql_gen.py index c33258b9..3f254042 100644 --- a/examples/sql_gen.py +++ b/examples/sql_gen.py @@ -1,8 +1,23 @@ +"""Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. + +Run with: + + uv run -m examples.sql_gen +""" + +import os from dataclasses import dataclass +from typing import cast -from pydantic_ai import Agent +# if you don't want to use logfire, just comment out these lines +import logfire from devtools import debug +from pydantic_ai import Agent +from pydantic_ai.agent import KnownModelName + +logfire.configure() + system_prompt = """\ Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. @@ -54,7 +69,8 @@ class Response: sql_query: str -agent = Agent('gemini-1.5-flash', result_type=Response, system_prompt=system_prompt, deps=None) +model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash')) +agent = Agent(model, result_type=Response, system_prompt=system_prompt, deps=None, retries=2) if __name__ == '__main__': diff --git a/examples/weather.py b/examples/weather.py index 9c32a24c..91bb1a5e 100644 --- a/examples/weather.py +++ b/examples/weather.py @@ -1,49 +1,131 @@ +"""Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. + +In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities, +the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use +the `get_weather` tool to get the weather. + +Run with: + + uv run -m examples.weather +""" + +import asyncio +import os +from dataclasses import dataclass +from typing import Any, cast + +# if you don't want to use logfire, just comment out these lines +import logfire from devtools import debug +from httpx import AsyncClient -from pydantic_ai import Agent +from pydantic_ai import Agent, CallContext, ModelRetry +from pydantic_ai.agent import KnownModelName -try: - import logfire -except ImportError: - pass -else: - logfire.configure() +logfire.configure() -weather_agent: Agent[None, str] = Agent('openai:gpt-4o', system_prompt='Be concise, reply with one sentence.') +@dataclass +class Deps: + client: AsyncClient + weather_api_key: str | None + geo_api_key: str | None -@weather_agent.retriever_plain -async def get_lat_lng(location_description: str) -> dict[str, float]: - """ - Get the latitude and longitude of a location. + +model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o')) +weather_agent: Agent[Deps, str] = Agent(model, system_prompt='Be concise, reply with one sentence.', retries=2) + + +@weather_agent.retriever_context +async def get_lat_lng(ctx: CallContext[Deps], location_description: str) -> dict[str, float]: + """Get the latitude and longitude of a location. Args: + ctx: The context. location_description: A description of a location. """ - if 'london' in location_description.lower(): + if ctx.deps.geo_api_key is None: + # if no API key is provided, return a dummy response (London) return {'lat': 51.1, 'lng': -0.1} - elif 'wiltshire' in location_description.lower(): - return {'lat': 51.1, 'lng': -2.11} - else: - return {'lat': 0, 'lng': 0} + params = { + 'q': location_description, + 'api_key': ctx.deps.geo_api_key, + } + r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params) + r.raise_for_status() + data = r.json() + if not data: + raise ModelRetry('Could not find the location') + return {'lat': data[0]['lat'], 'lng': data[0]['lon']} -@weather_agent.retriever_plain -async def get_whether(lat: float, lng: float) -> str: - """ - Get the weather at a location. + +@weather_agent.retriever_context +async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]: + """Get the weather at a location. Args: + ctx: The context. lat: Latitude of the location. lng: Longitude of the location. """ - if abs(lat - 51.1) < 0.1 and abs(lng + 0.1) < 0.1: - # it always rains in London - return 'Raining' - else: - return 'Sunny' + if ctx.deps.weather_api_key is None: + # if no API key is provided, return a dummy response + return {'temperature': '21 °C', 'description': 'Sunny'} + + params = { + 'apikey': ctx.deps.weather_api_key, + 'location': f'{lat},{lng}', + 'units': 'metric', + } + r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params) + r.raise_for_status() + data = r.json() + values = data['data']['values'] + # https://docs.tomorrow.io/reference/data-layers-weather-codes + code_lookup = { + 1000: 'Clear, Sunny', + 1100: 'Mostly Clear', + 1101: 'Partly Cloudy', + 1102: 'Mostly Cloudy', + 1001: 'Cloudy', + 2000: 'Fog', + 2100: 'Light Fog', + 4000: 'Drizzle', + 4001: 'Rain', + 4200: 'Light Rain', + 4201: 'Heavy Rain', + 5000: 'Snow', + 5001: 'Flurries', + 5100: 'Light Snow', + 5101: 'Heavy Snow', + 6000: 'Freezing Drizzle', + 6001: 'Freezing Rain', + 6200: 'Light Freezing Rain', + 6201: 'Heavy Freezing Rain', + 7000: 'Ice Pellets', + 7101: 'Heavy Ice Pellets', + 7102: 'Light Ice Pellets', + 8000: 'Thunderstorm', + } + return { + 'temperature': f'{values['temperatureApparent']:0.0f}°C', + 'description': code_lookup.get(values['weatherCode'], 'Unknown'), + } + + +async def main(): + async with AsyncClient() as client: + logfire.instrument_httpx() + # create a free API key at https://www.tomorrow.io/weather-api/ + weather_api_key = os.getenv('WEATHER_API_KEY') + # create a free API key at https://geocode.maps.co/ + geo_api_key = os.getenv('GEO_API_KEY') + deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key) + result = await weather_agent.run('What is the weather like in London and in Wiltshire?', deps=deps) + debug(result) + print('Response:', result.response) if __name__ == '__main__': - result = weather_agent.run_sync('What is the weather like in West London and in Wiltshire?') - debug(result) + asyncio.run(main()) diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index ff1f27e8..6ca71626 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Sequence from dataclasses import dataclass -from typing import Any, Callable, Generic, Literal, cast, overload +from typing import Any, Callable, Generic, Literal, cast, final, overload import logfire_api from pydantic import ValidationError @@ -12,18 +12,19 @@ from . import _result, _retriever as _r, _system_prompt, _utils, messages as _messages, models, shared from .shared import AgentDeps, ResultData -__all__ = ('Agent',) +__all__ = 'Agent', 'KnownModelName' KnownModelName = Literal[ 'openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo', 'gemini-1.5-flash', 'gemini-1.5-pro' ] _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') +@final @dataclass(init=False) class Agent(Generic[AgentDeps, ResultData]): """Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM.""" - # slots mostly for my sanity — knowing what attributes are available + # dataclass fields mostly for my sanity — knowing what attributes are available model: models.Model | None _result_schema: _result.ResultSchema[ResultData] | None _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -86,9 +87,10 @@ async def run( The result of the run. """ if model is not None: - model_ = models.infer_model(model) + custom_model = model_ = models.infer_model(model) elif self.model is not None: model_ = self.model + custom_model = None else: raise shared.UserError('`model` must be set either when creating the agent or when calling it.') @@ -111,7 +113,7 @@ async def run( cost = shared.Cost() with _logfire.span( - 'agent run {prompt=}', prompt=user_prompt, agent=self, model=model_, model_name=model_.name() + 'agent run {prompt=}', prompt=user_prompt, agent=self, custom_model=custom_model, model_name=model_.name() ) as run_span: try: while True: @@ -279,6 +281,7 @@ async def _handle_model_response( for call in model_response.calls: retriever = self._retrievers.get(call.tool_name) if retriever is None: + # should this be a retry error? raise shared.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}') coros.append(retriever.run(deps, call)) new_messages = await asyncio.gather(*coros) diff --git a/pyproject.toml b/pyproject.toml index c1e995ba..8b86f2d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev-dependencies = [ "coverage[toml]>=7.6.2", "devtools>=0.12.2", "anyio>=4.5.0", + "logfire[httpx]>=1.2.0", ] [tool.hatch.build.targets.wheel] @@ -106,6 +107,7 @@ quote-style = "single" [tool.ruff.lint.per-file-ignores] "tests/**.py" = ["D"] +"examples/**.py" = ["D"] [tool.pyright] typeCheckingMode = "strict" @@ -148,3 +150,6 @@ exclude_lines = [ '$\s*assert_never\(', 'if __name__ == .__main__.:', ] + +[tool.logfire] +ignore_no_config = true diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index fd620eab..5de77f56 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -76,9 +76,9 @@ def test_simple(): ) -def whether_model(messages: list[Message], info: AgentInfo) -> LLMMessage: # pragma: no cover +def weather_model(messages: list[Message], info: AgentInfo) -> LLMMessage: # pragma: no cover assert info.allow_text_result - assert info.retrievers.keys() == {'get_location', 'get_whether'} + assert info.retrievers.keys() == {'get_location', 'get_weather'} last = messages[-1] if last.role == 'user': return LLMToolCalls( @@ -91,15 +91,15 @@ def whether_model(messages: list[Message], info: AgentInfo) -> LLMMessage: # pr ) elif last.role == 'tool-return': if last.tool_name == 'get_location': - return LLMToolCalls(calls=[ToolCall.from_json('get_whether', last.model_response_str())]) - elif last.tool_name == 'get_whether': + return LLMToolCalls(calls=[ToolCall.from_json('get_weather', last.model_response_str())]) + elif last.tool_name == 'get_weather': location_name = next(m.content for m in messages if m.role == 'user') return LLMResponse(f'{last.content} in {location_name}') raise ValueError(f'Unexpected message: {last}') -weather_agent: Agent[None, str] = Agent(FunctionModel(whether_model)) +weather_agent: Agent[None, str] = Agent(FunctionModel(weather_model)) @weather_agent.retriever_plain @@ -112,7 +112,7 @@ async def get_location(location_description: str) -> str: @weather_agent.retriever_context -async def get_whether(_: CallContext[None], lat: int, lng: int): +async def get_weather(_: CallContext[None], lat: int, lng: int): if (lat, lng) == (51, 0): # it always rains in London return 'Raining' @@ -120,7 +120,7 @@ async def get_whether(_: CallContext[None], lat: int, lng: int): return 'Sunny' -def test_whether(): +def test_weather(): result = weather_agent.run_sync('London') assert result.response == 'Raining in London' assert result.message_history == snapshot( @@ -141,7 +141,7 @@ def test_whether(): LLMToolCalls( calls=[ ToolCall.from_json( - 'get_whether', + 'get_weather', '{"lat": 51, "lng": 0}', ) ], @@ -149,7 +149,7 @@ def test_whether(): role='llm-tool-calls', ), ToolReturn( - tool_name='get_whether', + tool_name='get_weather', content='Raining', timestamp=IsNow(), role='tool-return', diff --git a/uv.lock b/uv.lock index 32d90204..354042ef 100644 --- a/uv.lock +++ b/uv.lock @@ -532,6 +532,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/7f/37d9c3cbed1ef23b467c0c0039f35524595f8fd79f3acb54e647a0ccd590/logfire-1.2.0-py3-none-any.whl", hash = "sha256:edb2b441e418cf31877bd97e24b3755f873bb423f834cca66f315b25bde61ebd", size = 164724 }, ] +[package.optional-dependencies] +httpx = [ + { name = "opentelemetry-instrumentation-httpx" }, +] + [[package]] name = "logfire-api" version = "1.2.0" @@ -690,6 +695,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/7f/405c41d4f359121376c9d5117dcf68149b8122d3f6c718996d037bd4d800/opentelemetry_instrumentation-0.48b0-py3-none-any.whl", hash = "sha256:a69750dc4ba6a5c3eb67986a337185a25b739966d80479befe37b546fc870b44", size = 29449 }, ] +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/d9/c65d818607c16d1b7ea8d2de6111c6cecadf8d2fd38c1885a72733a7c6d3/opentelemetry_instrumentation_httpx-0.48b0.tar.gz", hash = "sha256:ee977479e10398931921fb995ac27ccdeea2e14e392cb27ef012fc549089b60a", size = 16931 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/fe/f2daa9d6d988c093b8c7b1d35df675761a8ece0b600b035dc04982746c9d/opentelemetry_instrumentation_httpx-0.48b0-py3-none-any.whl", hash = "sha256:d94f9d612c82d09fe22944d1904a30a464c19bea2ba76be656c99a28ad8be8e5", size = 13900 }, +] + [[package]] name = "opentelemetry-proto" version = "1.27.0" @@ -729,6 +749,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/7a/4f0063dbb0b6c971568291a8bc19a4ca70d3c185db2d956230dd67429dfc/opentelemetry_semantic_conventions-0.48b0-py3-none-any.whl", hash = "sha256:a0de9f45c413a8669788a38569c7e0a11ce6ce97861a628cca785deecdc32a1f", size = 149685 }, ] +[[package]] +name = "opentelemetry-util-http" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/d7/185c494754340e0a3928fd39fde2616ee78f2c9d66253affaad62d5b7935/opentelemetry_util_http-0.48b0.tar.gz", hash = "sha256:60312015153580cc20f322e5cdc3d3ecad80a71743235bdb77716e742814623c", size = 7863 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/2e/36097c0a4d0115b8c7e377c90bab7783ac183bc5cb4071308f8959454311/opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb", size = 6946 }, +] + [[package]] name = "packaging" version = "24.1" @@ -820,6 +849,7 @@ dev = [ { name = "devtools" }, { name = "dirty-equals" }, { name = "inline-snapshot" }, + { name = "logfire", extra = ["httpx"] }, { name = "mypy" }, { name = "pyright" }, { name = "pytest" }, @@ -845,6 +875,7 @@ dev = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.13.3" }, + { name = "logfire", extras = ["httpx"], specifier = ">=1.2.0" }, { name = "mypy", specifier = ">=1.11.2" }, { name = "pyright", specifier = ">=1.1.384" }, { name = "pytest", specifier = ">=8.3.3" }, From 16aa352fcbfd3f299a899d4050a4c625f541426f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 22 Oct 2024 04:36:34 +0100 Subject: [PATCH 2/6] fix example imports --- examples/pydantic_model.py | 5 +++-- examples/sql_gen.py | 5 +++-- examples/weather.py | 5 +++-- pyproject.toml | 3 ++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/pydantic_model.py b/examples/pydantic_model.py index 0714e611..309a47f0 100644 --- a/examples/pydantic_model.py +++ b/examples/pydantic_model.py @@ -8,13 +8,14 @@ import os from typing import cast -# if you don't want to use logfire, just comment out these lines -import logfire from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.agent import KnownModelName +# if you don't want to use logfire, just comment out these lines +import logfire + logfire.configure() diff --git a/examples/sql_gen.py b/examples/sql_gen.py index 3f254042..33a5719a 100644 --- a/examples/sql_gen.py +++ b/examples/sql_gen.py @@ -9,13 +9,14 @@ from dataclasses import dataclass from typing import cast -# if you don't want to use logfire, just comment out these lines -import logfire from devtools import debug from pydantic_ai import Agent from pydantic_ai.agent import KnownModelName +# if you don't want to use logfire, just comment out these lines +import logfire + logfire.configure() system_prompt = """\ diff --git a/examples/weather.py b/examples/weather.py index 91bb1a5e..8aa3b0fd 100644 --- a/examples/weather.py +++ b/examples/weather.py @@ -14,14 +14,15 @@ from dataclasses import dataclass from typing import Any, cast -# if you don't want to use logfire, just comment out these lines -import logfire from devtools import debug from httpx import AsyncClient from pydantic_ai import Agent, CallContext, ModelRetry from pydantic_ai.agent import KnownModelName +# if you don't want to use logfire, just comment out these lines +import logfire + logfire.configure() diff --git a/pyproject.toml b/pyproject.toml index 8b86f2d5..c24b96dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,8 @@ quote-style = "single" [tool.ruff.lint.per-file-ignores] "tests/**.py" = ["D"] -"examples/**.py" = ["D"] +# see https://github.com/astral-sh/ruff/issues/13871 +"examples/**.py" = ["D", "I001"] [tool.pyright] typeCheckingMode = "strict" From dbc2e90a37ad044fb3a2bc4af258263bb581cde4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 22 Oct 2024 04:41:02 +0100 Subject: [PATCH 3/6] tweak linting rules --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c24b96dd..5470d359 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,7 @@ quote-style = "single" [tool.ruff.lint.per-file-ignores] "tests/**.py" = ["D"] # see https://github.com/astral-sh/ruff/issues/13871 -"examples/**.py" = ["D", "I001"] +"examples/**.py" = ["D103", "I001"] [tool.pyright] typeCheckingMode = "strict" From 51373c462d13a39ff944f91047e2fdd1bdce44d1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 22 Oct 2024 04:44:07 +0100 Subject: [PATCH 4/6] tweak examples readme --- examples/README.md | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/examples/README.md b/examples/README.md index f8fcdee8..d942b9a2 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,12 +7,12 @@ Examples of how to use Pydantic AI and what it can do. To run the examples, run: ```bash -uv run -m examples. +uv run -m examples. ``` ## Examples -### `pydantic_model` +### `pydantic_model.py` Simple example of using Pydantic AI to construct a Pydantic model from a text input. @@ -27,15 +27,9 @@ with Gemini using: PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m examples.pydantic_model ``` -or +(or `PYDANTIC_AI_MODEL=gemini-1.5-flash...`) -```bash -PYDANTIC_AI_MODEL=gemini-1.5-flash uv run -m examples.pydantic_model -``` - -```bash - -### `sql_gen` +### `sql_gen.py` Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. @@ -45,7 +39,7 @@ uv run -m examples.sql_gen This model uses `gemini-1.5-flash` by default since Gemini is good at single shot queries. -### `weather` +### `weather.py` Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. From 4bd30dbc1d7e5451a6cb7f944a8a492346acf954 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 22 Oct 2024 06:18:11 +0100 Subject: [PATCH 5/6] more examples, switch ToolRetry to RetryPrompt --- examples/README.md | 20 ++++- examples/pydantic_model.py | 9 +- examples/sql_gen.py | 134 ++++++++++++++++++++-------- examples/weather.py | 32 ++++--- pydantic_ai/_result.py | 17 ++-- pydantic_ai/_retriever.py | 4 +- pydantic_ai/agent.py | 13 ++- pydantic_ai/messages.py | 8 +- pydantic_ai/models/gemini.py | 19 ++-- pydantic_ai/models/openai.py | 21 +++-- pydantic_ai/models/test.py | 4 +- pyproject.toml | 8 +- tests/models/test_gemini.py | 4 +- tests/models/test_model_function.py | 50 ++++++++++- tests/models/test_model_test.py | 4 +- tests/models/test_openai.py | 4 +- tests/test_agent.py | 8 +- uv.lock | 97 +++++++++++++------- 18 files changed, 316 insertions(+), 140 deletions(-) diff --git a/examples/README.md b/examples/README.md index d942b9a2..351c9fbc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,33 +14,45 @@ uv run -m examples. ### `pydantic_model.py` +(Demonstrates: custom `result_type`) + Simple example of using Pydantic AI to construct a Pydantic model from a text input. ```bash -uv run -m examples.pydantic_model +uv run --extra examples -m examples.pydantic_model ``` This examples uses `openai:gpt-4o` by default but it works well with other modesl, e.g. you can run it with Gemini using: ```bash -PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m examples.pydantic_model +PYDANTIC_AI_MODEL=gemini-1.5-pro uv run --extra examples -m examples.pydantic_model ``` (or `PYDANTIC_AI_MODEL=gemini-1.5-flash...`) ### `sql_gen.py` +(Demonstrates: custom `result_type`, dynamic system prompt, result validation, agent deps) + Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. ```bash -uv run -m examples.sql_gen +uv run --extra examples -m examples.sql_gen +``` + +or to use a custom prompt: + +```bash +uv run --extra examples -m examples.sql_gen "find me whatever" ``` This model uses `gemini-1.5-flash` by default since Gemini is good at single shot queries. ### `weather.py` +(Demonstrates: retrievers, multiple retrievers, agent deps) + Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities, @@ -54,7 +66,7 @@ To run this example properly, you'll need two extra API keys: **(Note if either key is missing, the code will fall back to dummy data.)** ```bash -uv run -m examples.weather +uv run --extra examples -m examples.weather ``` This example uses `openai:gpt-4o` by default. Gemini seems to be unable to handle the multiple tool diff --git a/examples/pydantic_model.py b/examples/pydantic_model.py index 309a47f0..93b927f2 100644 --- a/examples/pydantic_model.py +++ b/examples/pydantic_model.py @@ -2,21 +2,20 @@ Run with: - uv run -m examples.pydantic_model + uv run --extra examples -m examples.pydantic_model """ import os from typing import cast +import logfire from pydantic import BaseModel from pydantic_ai import Agent from pydantic_ai.agent import KnownModelName -# if you don't want to use logfire, just comment out these lines -import logfire - -logfire.configure() +# 'if-token-present' means nothing will be sent (and the example wil work) if you don't have logfire set up +logfire.configure(send_to_logfire='if-token-present') class MyModel(BaseModel): diff --git a/examples/sql_gen.py b/examples/sql_gen.py index 33a5719a..ec68588a 100644 --- a/examples/sql_gen.py +++ b/examples/sql_gen.py @@ -2,60 +2,79 @@ Run with: - uv run -m examples.sql_gen + uv run --extra examples -m examples.sql_gen """ +import asyncio import os +import sys +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import cast +from datetime import date +from typing import Annotated, Any, cast +import asyncpg +import logfire +from annotated_types import MinLen from devtools import debug -from pydantic_ai import Agent +from pydantic_ai import Agent, CallContext, ModelRetry from pydantic_ai.agent import KnownModelName -# if you don't want to use logfire, just comment out these lines -import logfire - +# 'if-token-present' means nothing will be sent (and the example wil work) if you don't have logfire set up logfire.configure() -system_prompt = """\ -Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. - -CREATE TABLE records AS ( - start_timestamp timestamp with time zone, - created_at timestamp with time zone, +DB_SCHEMA = """ +CREATE TABLE IF NOT EXISTS records ( + created_at timestamptz, + start_timestamp timestamptz, + end_timestamp timestamptz, trace_id text, span_id text, parent_span_id text, - kind span_kind, - end_timestamp timestamp with time zone, - level smallint, + level log_level, span_name text, message text, attributes_json_schema text, attributes jsonb, tags text[], - otel_links jsonb, - otel_events jsonb, is_exception boolean, - otel_status_code status_code, otel_status_message text, - otel_scope_name text, - otel_scope_version text, - otel_scope_attributes jsonb, - service_namespace text, - service_name text, - service_version text, - service_instance_id text, - process_pid integer + service_name text ); +""" + + +@dataclass +class Response: + sql_query: Annotated[str, MinLen(1)] + + +@dataclass +class Deps: + conn: asyncpg.Connection + + +model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash')) +agent: Agent[Deps, Response] = Agent(model, result_type=Response) + + +@agent.system_prompt +async def system_prompt() -> str: + return f"""\ +Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. -today's date = 2024-10-09 +{DB_SCHEMA} + +today's date = {date.today()} Example request: show me records where foobar is false - response: SELECT * FROM records WHERE attributes->>'foobar' = false' + response: SELECT * FROM records WHERE attributes->>'foobar' = false +Example + request: show me records where attributes include the key "foobar" + response: SELECT * FROM records WHERE attributes ? 'foobar' Example request: show me records from yesterday response: SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day' @@ -65,16 +84,59 @@ """ -@dataclass -class Response: - sql_query: str +@agent.result_validator +async def validate_result(ctx: CallContext[Deps], result: Response) -> Response: + result.sql_query = result.sql_query.replace('\\', '') + lower_query = result.sql_query.lower() + if not lower_query.startswith('select'): + raise ModelRetry('Please a SELECT query') + try: + await ctx.deps.conn.execute(f'EXPLAIN {result.sql_query}') + except asyncpg.exceptions.PostgresError as e: + raise ModelRetry(f'Invalid query: {e}') from e + else: + return result -model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash')) -agent = Agent(model, result_type=Response, system_prompt=system_prompt, deps=None, retries=2) +async def main(): + if len(sys.argv) == 1: + prompt = 'show me logs from yesterday, with level "error"' + else: + prompt = sys.argv[1] -if __name__ == '__main__': - with debug.timer('SQL Generation'): - result = agent.run_sync('show me logs from yesterday, with level "error"') + async with database_connect('postgresql://postgres@localhost', 'pydantic_ai_sql_gen') as conn: + deps = Deps(conn) + result = await agent.run(prompt, deps=deps) debug(result.response.sql_query) + + +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false +@asynccontextmanager +async def database_connect(server_dsn: str, database: str) -> AsyncGenerator[Any, None]: + with logfire.span('check and create DB'): + conn = await asyncpg.connect(server_dsn) + try: + db_exists = await conn.fetchval('SELECT 1 FROM pg_database WHERE datname = $1', database) + if not db_exists: + await conn.execute(f'CREATE DATABASE {database}') + finally: + await conn.close() + + conn = await asyncpg.connect(f'{server_dsn}/{database}') + try: + with logfire.span('create schema'): + async with conn.transaction(): + if not db_exists: + await conn.execute( + "CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical')" + ) + await conn.execute(DB_SCHEMA) + yield conn + finally: + await conn.close() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/weather.py b/examples/weather.py index 8aa3b0fd..0e51d241 100644 --- a/examples/weather.py +++ b/examples/weather.py @@ -6,7 +6,7 @@ Run with: - uv run -m examples.weather + uv run --extra examples -m examples.weather """ import asyncio @@ -14,16 +14,15 @@ from dataclasses import dataclass from typing import Any, cast +import logfire from devtools import debug from httpx import AsyncClient from pydantic_ai import Agent, CallContext, ModelRetry from pydantic_ai.agent import KnownModelName -# if you don't want to use logfire, just comment out these lines -import logfire - -logfire.configure() +# 'if-token-present' means nothing will be sent (and the example wil work) if you don't have logfire set up +logfire.configure(send_to_logfire='if-token-present') @dataclass @@ -53,12 +52,16 @@ async def get_lat_lng(ctx: CallContext[Deps], location_description: str) -> dict 'q': location_description, 'api_key': ctx.deps.geo_api_key, } - r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params) - r.raise_for_status() - data = r.json() - if not data: + with logfire.span('calling geocode API', params=params) as span: + r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params) + r.raise_for_status() + data = r.json() + span.set_attribute('response', data) + + if data: + return {'lat': data[0]['lat'], 'lng': data[0]['lon']} + else: raise ModelRetry('Could not find the location') - return {'lat': data[0]['lat'], 'lng': data[0]['lon']} @weather_agent.retriever_context @@ -79,9 +82,12 @@ async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[st 'location': f'{lat},{lng}', 'units': 'metric', } - r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params) - r.raise_for_status() - data = r.json() + with logfire.span('calling weather API', params=params) as span: + r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params) + r.raise_for_status() + data = r.json() + span.set_attribute('response', data) + values = data['data']['values'] # https://docs.tomorrow.io/reference/data-layers-weather-codes code_lookup = { diff --git a/pydantic_ai/_result.py b/pydantic_ai/_result.py index a70aa934..c586e005 100644 --- a/pydantic_ai/_result.py +++ b/pydantic_ai/_result.py @@ -33,7 +33,7 @@ def __post_init__(self): self._is_async = inspect.iscoroutinefunction(self.function) async def validate( - self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall + self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | None ) -> ResultData: """Validate a result but calling the function. @@ -41,7 +41,7 @@ async def validate( result: The result data after Pydantic validation the message content. deps: The agent dependencies. retry: The current retry number. - tool_call: The original tool call message. + tool_call: The original tool call message, `None` if there was no tool call. Returns: Result of either the validated result data (ok) or a retry message (Err). @@ -59,11 +59,10 @@ async def validate( function = cast(Callable[[Any], ResultData], self.function) result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: - m = messages.ToolRetry( - tool_name=tool_call.tool_name, - content=r.message, - tool_id=tool_call.tool_id, - ) + m = messages.RetryPrompt(content=r.message) + if tool_call is not None: + m.tool_name = tool_call.tool_name + m.tool_id = tool_call.tool_id raise ToolRetryError(m) from r else: return result_data @@ -72,7 +71,7 @@ async def validate( class ToolRetryError(Exception): """Internal exception used to indicate a signal a `ToolRetry` message should be returned to the LLM.""" - def __init__(self, tool_retry: messages.ToolRetry): + def __init__(self, tool_retry: messages.RetryPrompt): self.tool_retry = tool_retry super().__init__() @@ -127,7 +126,7 @@ def validate(self, tool_call: messages.ToolCall) -> ResultData: else: result = self.type_adapter.validate_python(tool_call.args.args_object) except ValidationError as e: - m = messages.ToolRetry( + m = messages.RetryPrompt( tool_name=tool_call.tool_name, content=e.errors(include_url=False), tool_id=tool_call.tool_id, diff --git a/pydantic_ai/_retriever.py b/pydantic_ai/_retriever.py index bd940a2d..5ba3faef 100644 --- a/pydantic_ai/_retriever.py +++ b/pydantic_ai/_retriever.py @@ -104,13 +104,13 @@ def _call_args(self, deps: AgentDeps, args_dict: dict[str, Any]) -> tuple[list[A def _on_error( self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.ToolCall - ) -> messages.ToolRetry: + ) -> messages.RetryPrompt: self._current_retry += 1 if self._current_retry > self.max_retries: # TODO custom error with details of the retriever raise else: - return messages.ToolRetry( + return messages.RetryPrompt( tool_name=call_message.tool_name, content=content, tool_id=call_message.tool_id, diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 6ca71626..0b27afdf 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -253,11 +253,18 @@ async def _handle_model_response( if model_response.role == 'llm-response': # plain string response if self._allow_text_result: - return _utils.Either(left=cast(ResultData, model_response.content)) + result_data_input = cast(ResultData, model_response.content) + try: + result_data = await self._validate_result(result_data_input, deps, None) + except _result.ToolRetryError as e: + self._incr_result_retry() + return _utils.Either(right=[e.tool_retry]) + else: + return _utils.Either(left=result_data) else: self._incr_result_retry() assert self._result_schema is not None - response = _messages.UserPrompt( + response = _messages.RetryPrompt( content='Plain text responses are not permitted, please call one of the functions instead.', ) return _utils.Either(right=[response]) @@ -290,7 +297,7 @@ async def _handle_model_response( assert_never(model_response) async def _validate_result( - self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall + self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None ) -> ResultData: for validator in self._result_validators: result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call) diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py index bb89cedb..b8a06479 100644 --- a/pydantic_ai/messages.py +++ b/pydantic_ai/messages.py @@ -50,12 +50,12 @@ def model_response_object(self) -> dict[str, Any]: @dataclass -class ToolRetry: - tool_name: str +class RetryPrompt: content: list[pydantic_core.ErrorDetails] | str + tool_name: str | None = None tool_id: str | None = None timestamp: datetime = field(default_factory=datetime.now) - role: Literal['tool-retry'] = 'tool-retry' + role: Literal['retry-prompt'] = 'retry-prompt' def model_response(self) -> str: if isinstance(self.content, str): @@ -107,6 +107,6 @@ class LLMToolCalls: LLMMessage = Union[LLMResponse, LLMToolCalls] -Message = Union[SystemPrompt, UserPrompt, ToolReturn, ToolRetry, LLMMessage] +Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, LLMMessage] MessagesTypeAdapter = pydantic.TypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]]) diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index 32d24e01..eaf81857 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -34,8 +34,8 @@ LLMResponse, LLMToolCalls, Message, + RetryPrompt, ToolCall, - ToolRetry, ToolReturn, ) from . import AbstractToolDefinition, AgentModel, Model, cached_async_http_client @@ -169,8 +169,8 @@ def message_to_gemini(m: Message) -> _utils.Either[_GeminiTextPart, _GeminiConte elif m.role == 'tool-return': # ToolReturn -> return _utils.Either(right=_GeminiContent.function_return(m)) - elif m.role == 'tool-retry': - # ToolRetry -> + elif m.role == 'retry-prompt': + # RetryPrompt -> return _utils.Either(right=_GeminiContent.function_retry(m)) elif m.role == 'llm-response': # LLMResponse -> @@ -226,13 +226,16 @@ def function_call(cls, m: LLMToolCalls) -> _GeminiContent: @classmethod def function_return(cls, m: ToolReturn) -> _GeminiContent: f_response = _GeminiFunctionResponsePart.from_response(m.tool_name, m.model_response_object()) - return _GeminiContent(role='user', parts=[f_response]) + return cls(role='user', parts=[f_response]) @classmethod - def function_retry(cls, m: ToolRetry) -> _GeminiContent: - response = {'call_error': m.model_response()} - f_response = _GeminiFunctionResponsePart.from_response(m.tool_name, response) - return _GeminiContent(role='user', parts=[f_response]) + def function_retry(cls, m: RetryPrompt) -> _GeminiContent: + if m.tool_name is None: + part = _GeminiTextPart(text=m.model_response()) + else: + response = {'call_error': m.model_response()} + part = _GeminiFunctionResponsePart.from_response(m.tool_name, response) + return cls(role='user', parts=[part]) @dataclass diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 1561ebe1..978709ce 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -17,8 +17,8 @@ LLMResponse, LLMToolCalls, Message, + RetryPrompt, ToolCall, - ToolRetry, ToolReturn, ) from . import AbstractToolDefinition, AgentModel, Model, cached_async_http_client @@ -136,13 +136,16 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam: tool_call_id=_guard_tool_id(message), content=message.model_response_str(), ) - elif message.role == 'tool-retry': - # ToolRetry -> - return chat.ChatCompletionToolMessageParam( - role='tool', - tool_call_id=_guard_tool_id(message), - content=message.model_response(), - ) + elif message.role == 'retry-prompt': + # RetryPrompt -> + if message.tool_name is None: + return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response()) + else: + return chat.ChatCompletionToolMessageParam( + role='tool', + tool_call_id=_guard_tool_id(message), + content=message.model_response(), + ) elif message.role == 'llm-response': # LLMResponse -> return chat.ChatCompletionAssistantMessageParam(role='assistant', content=message.content) @@ -156,7 +159,7 @@ def map_message(message: Message) -> chat.ChatCompletionMessageParam: assert_never(message) -def _guard_tool_id(t: ToolCall | ToolReturn | ToolRetry) -> str: +def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str: """Type guard that checks a `tool_id` is not None both for static typing and runtime.""" assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}' return t.tool_id diff --git a/pydantic_ai/models/test.py b/pydantic_ai/models/test.py index a6361a46..80444e44 100644 --- a/pydantic_ai/models/test.py +++ b/pydantic_ai/models/test.py @@ -15,7 +15,7 @@ import pydantic_core from .. import _utils, shared -from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, ToolCall, ToolRetry, ToolReturn +from ..messages import LLMMessage, LLMResponse, LLMToolCalls, Message, RetryPrompt, ToolCall, ToolReturn from . import AbstractToolDefinition, AgentModel, Model @@ -91,7 +91,7 @@ async def request(self, messages: list[Message]) -> tuple[LLMMessage, shared.Cos new_messages = messages[self.last_message_count :] self.last_message_count = len(messages) - new_retry_names = {m.tool_name for m in new_messages if isinstance(m, ToolRetry)} + new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)} if new_retry_names: calls = [ ToolCall.from_object(name, self.gen_retriever_args(args)) diff --git a/pyproject.toml b/pyproject.toml index 5470d359..2324f49a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,10 @@ dependencies = [ logfire = [ "logfire>=1.2.0", ] +examples = [ + "asyncpg>=0.30.0", + "logfire>=1.2.0", +] [tool.uv] dev-dependencies = [ @@ -59,7 +63,6 @@ dev-dependencies = [ "coverage[toml]>=7.6.2", "devtools>=0.12.2", "anyio>=4.5.0", - "logfire[httpx]>=1.2.0", ] [tool.hatch.build.targets.wheel] @@ -107,8 +110,7 @@ quote-style = "single" [tool.ruff.lint.per-file-ignores] "tests/**.py" = ["D"] -# see https://github.com/astral-sh/ruff/issues/13871 -"examples/**.py" = ["D103", "I001"] +"examples/**.py" = ["D103"] [tool.pyright] typeCheckingMode = "strict" diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 19cdaaa2..85aaad8c 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -16,9 +16,9 @@ ArgsObject, LLMResponse, LLMToolCalls, + RetryPrompt, SystemPrompt, ToolCall, - ToolRetry, ToolReturn, UserPrompt, ) @@ -437,7 +437,7 @@ async def get_location(loc_name: str) -> str: ], timestamp=IsNow(), ), - ToolRetry(tool_name='get_location', content='Wrong location, please try again', timestamp=IsNow()), + RetryPrompt(tool_name='get_location', content='Wrong location, please try again', timestamp=IsNow()), LLMToolCalls( calls=[ ToolCall( diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 5de77f56..0810bda2 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -4,8 +4,9 @@ import pydantic_core import pytest from inline_snapshot import snapshot +from pydantic import BaseModel -from pydantic_ai import Agent, CallContext +from pydantic_ai import Agent, CallContext, ModelRetry from pydantic_ai.messages import ( LLMMessage, LLMResponse, @@ -335,3 +336,50 @@ def test_call_all(): LLMResponse(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow()), ] ) + + +def test_retry_str(): + call_count = 0 + + def try_again(messages: list[Message], _: AgentInfo) -> LLMMessage: + nonlocal call_count + call_count += 1 + + return LLMResponse(str(call_count)) + + agent = Agent(FunctionModel(try_again), deps=None) + + @agent.result_validator + async def validate_result(r: str) -> str: + if r == '1': + raise ModelRetry('Try again') + else: + return r + + result = agent.run_sync('') + assert result.response == snapshot('2') + + +def test_retry_result_type(): + call_count = 0 + + def try_again(messages: list[Message], _: AgentInfo) -> LLMMessage: + nonlocal call_count + call_count += 1 + + return LLMToolCalls(calls=[ToolCall.from_object('final_result', {'x': call_count})]) + + class Foo(BaseModel): + x: int + + agent = Agent(FunctionModel(try_again), result_type=Foo, deps=None) + + @agent.result_validator + async def validate_result(r: Foo) -> Foo: + if r.x == 1: + raise ModelRetry('Try again') + else: + return r + + result = agent.run_sync('') + assert result.response == snapshot(Foo(x=2)) diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index cd213657..f3c14572 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from pydantic_ai import Agent, ModelRetry -from pydantic_ai.messages import LLMResponse, LLMToolCalls, ToolCall, ToolRetry, ToolReturn, UserPrompt +from pydantic_ai.messages import LLMResponse, LLMToolCalls, RetryPrompt, ToolCall, ToolReturn, UserPrompt from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData # pyright: ignore[reportPrivateUsage] from tests.conftest import IsNow @@ -88,7 +88,7 @@ async def my_ret(x: int) -> str: calls=[ToolCall.from_object('my_ret', {'x': 0})], timestamp=IsNow(), ), - ToolRetry(tool_name='my_ret', content='First call failed', timestamp=IsNow()), + RetryPrompt(tool_name='my_ret', content='First call failed', timestamp=IsNow()), LLMToolCalls(calls=[ToolCall.from_object('my_ret', {'x': 1})], timestamp=IsNow()), ToolReturn(tool_name='my_ret', content='2', timestamp=IsNow()), LLMResponse(content='{"my_ret":"2"}', timestamp=IsNow()), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 4e74de9b..e97308f2 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -18,9 +18,9 @@ ArgsJson, LLMResponse, LLMToolCalls, + RetryPrompt, SystemPrompt, ToolCall, - ToolRetry, ToolReturn, UserPrompt, ) @@ -199,7 +199,7 @@ async def get_location(loc_name: str) -> str: ], timestamp=datetime.datetime(2024, 1, 1, 0, 0), ), - ToolRetry( + RetryPrompt( tool_name='get_location', content='Wrong location, please try again', tool_id='1', timestamp=IsNow() ), LLMToolCalls( diff --git a/tests/test_agent.py b/tests/test_agent.py index 6ce31ec9..4740cf34 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -8,8 +8,8 @@ LLMResponse, LLMToolCalls, Message, + RetryPrompt, ToolCall, - ToolRetry, UserPrompt, ) from pydantic_ai.models.function import AgentInfo, FunctionModel @@ -67,7 +67,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> LLMMessage: calls=[ToolCall.from_json('final_result', '{"a": "wrong", "b": "foo"}')], timestamp=IsNow(), ), - ToolRetry( + RetryPrompt( tool_name='final_result', content=[ { @@ -112,7 +112,7 @@ def validate_result(r: Foo) -> Foo: [ UserPrompt(content='Hello', timestamp=IsNow()), LLMToolCalls(calls=[ToolCall.from_json('final_result', '{"a": 41, "b": "foo"}')], timestamp=IsNow()), - ToolRetry(tool_name='final_result', content='"a" should be 42', timestamp=IsNow()), + RetryPrompt(tool_name='final_result', content='"a" should be 42', timestamp=IsNow()), LLMToolCalls(calls=[ToolCall.from_json('final_result', '{"a": 42, "b": "foo"}')], timestamp=IsNow()), ] ) @@ -141,7 +141,7 @@ def return_tuple(_: list[Message], info: AgentInfo) -> LLMMessage: [ UserPrompt(content='Hello', timestamp=IsNow()), LLMResponse(content='hello', timestamp=IsNow()), - UserPrompt( + RetryPrompt( content='Plain text responses are not permitted, please call one of the functions instead.', timestamp=IsNow(), ), diff --git a/uv.lock b/uv.lock index 354042ef..53472029 100644 --- a/uv.lock +++ b/uv.lock @@ -41,6 +41,66 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/86/4736ac618d82a20d87d2f92ae19441ebc7ac9e7a581d7e58bbe79233b24a/asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24", size = 27764 }, ] +[[package]] +name = "async-timeout" +version = "4.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/87/d6/21b30a550dafea84b1b8eee21b5e23fa16d010ae006011221f33dcd8d7f8/async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", size = 8345 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028", size = 5721 }, +] + +[[package]] +name = "asyncpg" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/07/1650a8c30e3a5c625478fa8aafd89a8dd7d85999bf7169b16f54973ebf2c/asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e", size = 673143 }, + { url = "https://files.pythonhosted.org/packages/a0/9a/568ff9b590d0954553c56806766914c149609b828c426c5118d4869111d3/asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0", size = 645035 }, + { url = "https://files.pythonhosted.org/packages/de/11/6f2fa6c902f341ca10403743701ea952bca896fc5b07cc1f4705d2bb0593/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f", size = 2912384 }, + { url = "https://files.pythonhosted.org/packages/83/83/44bd393919c504ffe4a82d0aed8ea0e55eb1571a1dea6a4922b723f0a03b/asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af", size = 2947526 }, + { url = "https://files.pythonhosted.org/packages/08/85/e23dd3a2b55536eb0ded80c457b0693352262dc70426ef4d4a6fc994fa51/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75", size = 2895390 }, + { url = "https://files.pythonhosted.org/packages/9b/26/fa96c8f4877d47dc6c1864fef5500b446522365da3d3d0ee89a5cce71a3f/asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f", size = 3015630 }, + { url = "https://files.pythonhosted.org/packages/34/00/814514eb9287614188a5179a8b6e588a3611ca47d41937af0f3a844b1b4b/asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf", size = 568760 }, + { url = "https://files.pythonhosted.org/packages/f0/28/869a7a279400f8b06dd237266fdd7220bc5f7c975348fea5d1e6909588e9/asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50", size = 625764 }, + { url = "https://files.pythonhosted.org/packages/4c/0e/f5d708add0d0b97446c402db7e8dd4c4183c13edaabe8a8500b411e7b495/asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a", size = 674506 }, + { url = "https://files.pythonhosted.org/packages/6a/a0/67ec9a75cb24a1d99f97b8437c8d56da40e6f6bd23b04e2f4ea5d5ad82ac/asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed", size = 645922 }, + { url = "https://files.pythonhosted.org/packages/5c/d9/a7584f24174bd86ff1053b14bb841f9e714380c672f61c906eb01d8ec433/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a", size = 3079565 }, + { url = "https://files.pythonhosted.org/packages/a0/d7/a4c0f9660e333114bdb04d1a9ac70db690dd4ae003f34f691139a5cbdae3/asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956", size = 3109962 }, + { url = "https://files.pythonhosted.org/packages/3c/21/199fd16b5a981b1575923cbb5d9cf916fdc936b377e0423099f209e7e73d/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056", size = 3064791 }, + { url = "https://files.pythonhosted.org/packages/77/52/0004809b3427534a0c9139c08c87b515f1c77a8376a50ae29f001e53962f/asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454", size = 3188696 }, + { url = "https://files.pythonhosted.org/packages/52/cb/fbad941cd466117be58b774a3f1cc9ecc659af625f028b163b1e646a55fe/asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d", size = 567358 }, + { url = "https://files.pythonhosted.org/packages/3c/0a/0a32307cf166d50e1ad120d9b81a33a948a1a5463ebfa5a96cc5606c0863/asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f", size = 629375 }, + { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162 }, + { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025 }, + { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243 }, + { url = "https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059 }, + { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596 }, + { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632 }, + { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186 }, + { url = "https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064 }, + { url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373 }, + { url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745 }, + { url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103 }, + { url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471 }, + { url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253 }, + { url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720 }, + { url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404 }, + { url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623 }, + { url = "https://files.pythonhosted.org/packages/b4/82/d94f3ed6921136a0ef40a825740eda19437ccdad7d92d924302dca1d5c9e/asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad", size = 673026 }, + { url = "https://files.pythonhosted.org/packages/4e/db/7db8b73c5d86ec9a21807f405e0698f8f637a8a3ca14b7b6fd4259b66bcf/asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff", size = 644732 }, + { url = "https://files.pythonhosted.org/packages/eb/a0/1f1910659d08050cb3e8f7d82b32983974798d7fd4ddf7620b8e2023d4ac/asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708", size = 2911761 }, + { url = "https://files.pythonhosted.org/packages/4d/53/5aa0d92488ded50bab2b6626430ed9743b0b7e2d864a2b435af1ccbf219a/asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144", size = 2946595 }, + { url = "https://files.pythonhosted.org/packages/c5/cd/d6d548d8ee721f4e0f7fbbe509bbac140d556c2e45814d945540c96cf7d4/asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb", size = 2890135 }, + { url = "https://files.pythonhosted.org/packages/46/f0/28df398b685dabee20235e24880e1f6486d84ae7e6b0d11bdebc17740e7a/asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547", size = 3011889 }, + { url = "https://files.pythonhosted.org/packages/c8/07/8c7ffe6fe8bccff9b12fcb6410b1b2fa74b917fd8b837806a40217d5228b/asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a", size = 569406 }, + { url = "https://files.pythonhosted.org/packages/05/51/f59e4df6d9b8937530d4b9fdee1598b93db40c631fe94ff3ce64207b7a95/asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773", size = 626581 }, +] + [[package]] name = "black" version = "24.8.0" @@ -532,11 +592,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/7f/37d9c3cbed1ef23b467c0c0039f35524595f8fd79f3acb54e647a0ccd590/logfire-1.2.0-py3-none-any.whl", hash = "sha256:edb2b441e418cf31877bd97e24b3755f873bb423f834cca66f315b25bde61ebd", size = 164724 }, ] -[package.optional-dependencies] -httpx = [ - { name = "opentelemetry-instrumentation-httpx" }, -] - [[package]] name = "logfire-api" version = "1.2.0" @@ -695,21 +750,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/7f/405c41d4f359121376c9d5117dcf68149b8122d3f6c718996d037bd4d800/opentelemetry_instrumentation-0.48b0-py3-none-any.whl", hash = "sha256:a69750dc4ba6a5c3eb67986a337185a25b739966d80479befe37b546fc870b44", size = 29449 }, ] -[[package]] -name = "opentelemetry-instrumentation-httpx" -version = "0.48b0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "opentelemetry-api" }, - { name = "opentelemetry-instrumentation" }, - { name = "opentelemetry-semantic-conventions" }, - { name = "opentelemetry-util-http" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d3/d9/c65d818607c16d1b7ea8d2de6111c6cecadf8d2fd38c1885a72733a7c6d3/opentelemetry_instrumentation_httpx-0.48b0.tar.gz", hash = "sha256:ee977479e10398931921fb995ac27ccdeea2e14e392cb27ef012fc549089b60a", size = 16931 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/fe/f2daa9d6d988c093b8c7b1d35df675761a8ece0b600b035dc04982746c9d/opentelemetry_instrumentation_httpx-0.48b0-py3-none-any.whl", hash = "sha256:d94f9d612c82d09fe22944d1904a30a464c19bea2ba76be656c99a28ad8be8e5", size = 13900 }, -] - [[package]] name = "opentelemetry-proto" version = "1.27.0" @@ -749,15 +789,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/7a/4f0063dbb0b6c971568291a8bc19a4ca70d3c185db2d956230dd67429dfc/opentelemetry_semantic_conventions-0.48b0-py3-none-any.whl", hash = "sha256:a0de9f45c413a8669788a38569c7e0a11ce6ce97861a628cca785deecdc32a1f", size = 149685 }, ] -[[package]] -name = "opentelemetry-util-http" -version = "0.48b0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d6/d7/185c494754340e0a3928fd39fde2616ee78f2c9d66253affaad62d5b7935/opentelemetry_util_http-0.48b0.tar.gz", hash = "sha256:60312015153580cc20f322e5cdc3d3ecad80a71743235bdb77716e742814623c", size = 7863 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/2e/36097c0a4d0115b8c7e377c90bab7783ac183bc5cb4071308f8959454311/opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb", size = 6946 }, -] - [[package]] name = "packaging" version = "24.1" @@ -838,6 +869,10 @@ dependencies = [ ] [package.optional-dependencies] +examples = [ + { name = "asyncpg" }, + { name = "logfire" }, +] logfire = [ { name = "logfire" }, ] @@ -849,7 +884,6 @@ dev = [ { name = "devtools" }, { name = "dirty-equals" }, { name = "inline-snapshot" }, - { name = "logfire", extra = ["httpx"] }, { name = "mypy" }, { name = "pyright" }, { name = "pytest" }, @@ -859,9 +893,11 @@ dev = [ [package.metadata] requires-dist = [ + { name = "asyncpg", marker = "extra == 'examples'", specifier = ">=0.30.0" }, { name = "eval-type-backport", specifier = ">=0.2.0" }, { name = "griffe", specifier = ">=1.3.2" }, { name = "httpx", specifier = ">=0.27.2" }, + { name = "logfire", marker = "extra == 'examples'", specifier = ">=1.2.0" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=1.2.0" }, { name = "logfire-api", specifier = ">=1.2.0" }, { name = "openai", specifier = ">=1.51.2" }, @@ -875,7 +911,6 @@ dev = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "dirty-equals", specifier = ">=0.8.0" }, { name = "inline-snapshot", specifier = ">=0.13.3" }, - { name = "logfire", extras = ["httpx"], specifier = ">=1.2.0" }, { name = "mypy", specifier = ">=1.11.2" }, { name = "pyright", specifier = ">=1.1.384" }, { name = "pytest", specifier = ">=8.3.3" }, From 150a915865d5ec575473cd669c1f55abf51ffe8f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 23 Oct 2024 15:58:42 +0100 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: hyperlint-ai[bot] <154288675+hyperlint-ai[bot]@users.noreply.github.com> --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 351c9fbc..1cece0c7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -22,7 +22,7 @@ Simple example of using Pydantic AI to construct a Pydantic model from a text in uv run --extra examples -m examples.pydantic_model ``` -This examples uses `openai:gpt-4o` by default but it works well with other modesl, e.g. you can run it +This examples uses `openai:gpt-4o` by default but it works well with other models, e.g. you can run it with Gemini using: ```bash