Improve examples #18

Merged 6 commits on Oct 23, 2024
Changes from 4 commits
9 changes: 4 additions & 5 deletions Makefile
@@ -1,5 +1,4 @@
.DEFAULT_GOAL := all
sources = pydantic_ai tests

.PHONY: .uv # Check that uv is installed
.uv:
@@ -16,13 +15,13 @@ install: .uv .pre-commit

.PHONY: format # Format the code
format:
uv run ruff format $(sources)
uv run ruff check --fix --fix-only $(sources)
uv run ruff format
uv run ruff check --fix --fix-only

.PHONY: lint # Lint the code
lint:
uv run ruff format --check $(sources)
uv run ruff check $(sources)
uv run ruff format --check
uv run ruff check

.PHONY: typecheck-pyright
typecheck-pyright:
61 changes: 61 additions & 0 deletions examples/README.md
@@ -0,0 +1,61 @@
# Pydantic AI Examples

Examples of how to use Pydantic AI and what it can do.

## Usage

To run an example, run:

```bash
uv run -m examples.<example_module_name>
```

## Examples

### `pydantic_model.py`

Simple example of using Pydantic AI to construct a Pydantic model from a text input.

```bash
uv run -m examples.pydantic_model
```

This example uses `openai:gpt-4o` by default, but it works well with other models; e.g. you can run it
with Gemini using:

```bash
PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m examples.pydantic_model
```

(or `PYDANTIC_AI_MODEL=gemini-1.5-flash...`)
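The same environment-variable override works for every example in this PR. Each script follows a pattern like the sketch below (`pick_model` is a hypothetical helper name, not part of `pydantic_ai`):

```python
import os

# Each example reads PYDANTIC_AI_MODEL and falls back to a per-example default.
def pick_model(default: str = 'openai:gpt-4o') -> str:
    return os.getenv('PYDANTIC_AI_MODEL', default)
```

In the actual example files this string is passed straight to `Agent(...)`, after a `cast` to `KnownModelName` to satisfy the type checker.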

### `sql_gen.py`

Example demonstrating how to use Pydantic AI to generate SQL queries based on user input.

```bash
uv run -m examples.sql_gen
```

This example uses `gemini-1.5-flash` by default, since Gemini performs well on single-shot queries.

### `weather.py`

Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question.

In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities,
the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use
the `get_weather` tool to get the weather.

To run this example properly, you'll need two extra API keys:
* A weather API key from [tomorrow.io](https://www.tomorrow.io/weather-api/) set via `WEATHER_API_KEY`
* A geocoding API key from [geocode.maps.co](https://geocode.maps.co/) set via `GEO_API_KEY`

**(Note: if either key is missing, the code will fall back to dummy data.)**

```bash
uv run -m examples.weather
```

This example uses `openai:gpt-4o` by default. Gemini seems to be unable to handle the multiple tool
calls.
15 changes: 0 additions & 15 deletions examples/parse_model.py

This file was deleted.

32 changes: 32 additions & 0 deletions examples/pydantic_model.py
@@ -0,0 +1,32 @@
"""Simple example of using Pydantic AI to construct a Pydantic model from a text input.

Run with:

uv run -m examples.pydantic_model
"""

import os
from typing import cast

from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.agent import KnownModelName

# if you don't want to use logfire, just comment out these lines
import logfire

logfire.configure()


class MyModel(BaseModel):
city: str
country: str


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
agent = Agent(model, result_type=MyModel, deps=None)

if __name__ == '__main__':
result = agent.run_sync('The windy city in the US of A.')
print(result.response)
21 changes: 19 additions & 2 deletions examples/sql_gen.py
@@ -1,8 +1,24 @@
"""Example demonstrating how to use Pydantic AI to generate SQL queries based on user input.

Run with:

uv run -m examples.sql_gen
"""

import os
from dataclasses import dataclass
from typing import cast

from pydantic_ai import Agent
from devtools import debug

from pydantic_ai import Agent
from pydantic_ai.agent import KnownModelName

# if you don't want to use logfire, just comment out these lines
import logfire

logfire.configure()

system_prompt = """\
Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request.

@@ -54,7 +70,8 @@ class Response:
sql_query: str


agent = Agent('gemini-1.5-flash', result_type=Response, system_prompt=system_prompt, deps=None)
model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash'))
agent = Agent(model, result_type=Response, system_prompt=system_prompt, deps=None, retries=2)


if __name__ == '__main__':
139 changes: 111 additions & 28 deletions examples/weather.py
@@ -1,49 +1,132 @@
"""Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question.

In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities,
the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use
the `get_weather` tool to get the weather.

Run with:

uv run -m examples.weather
"""

import asyncio
import os
from dataclasses import dataclass
from typing import Any, cast

from devtools import debug
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai.agent import KnownModelName

try:
import logfire
except ImportError:
pass
else:
logfire.configure()
# if you don't want to use logfire, just comment out these lines
import logfire

weather_agent: Agent[None, str] = Agent('openai:gpt-4o', system_prompt='Be concise, reply with one sentence.')
logfire.configure()


@weather_agent.retriever_plain
async def get_lat_lng(location_description: str) -> dict[str, float]:
"""
Get the latitude and longitude of a location.
@dataclass
class Deps:
client: AsyncClient
weather_api_key: str | None
geo_api_key: str | None


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
weather_agent: Agent[Deps, str] = Agent(model, system_prompt='Be concise, reply with one sentence.', retries=2)


@weather_agent.retriever_context
async def get_lat_lng(ctx: CallContext[Deps], location_description: str) -> dict[str, float]:
"""Get the latitude and longitude of a location.

Args:
ctx: The context.
location_description: A description of a location.
"""
if 'london' in location_description.lower():
if ctx.deps.geo_api_key is None:
# if no API key is provided, return a dummy response (London)
return {'lat': 51.1, 'lng': -0.1}
elif 'wiltshire' in location_description.lower():
return {'lat': 51.1, 'lng': -2.11}
else:
return {'lat': 0, 'lng': 0}

params = {
'q': location_description,
'api_key': ctx.deps.geo_api_key,
}
r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
r.raise_for_status()
data = r.json()
if not data:
raise ModelRetry('Could not find the location')
return {'lat': data[0]['lat'], 'lng': data[0]['lon']}

@weather_agent.retriever_plain
async def get_whether(lat: float, lng: float) -> str:
"""
Get the weather at a location.

@weather_agent.retriever_context
async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]:
"""Get the weather at a location.

Args:
ctx: The context.
lat: Latitude of the location.
lng: Longitude of the location.
"""
if abs(lat - 51.1) < 0.1 and abs(lng + 0.1) < 0.1:
# it always rains in London
return 'Raining'
else:
return 'Sunny'
if ctx.deps.weather_api_key is None:
# if no API key is provided, return a dummy response
return {'temperature': '21 °C', 'description': 'Sunny'}

params = {
'apikey': ctx.deps.weather_api_key,
'location': f'{lat},{lng}',
'units': 'metric',
}
r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params)
r.raise_for_status()
data = r.json()
values = data['data']['values']
# https://docs.tomorrow.io/reference/data-layers-weather-codes
code_lookup = {
1000: 'Clear, Sunny',
1100: 'Mostly Clear',
1101: 'Partly Cloudy',
1102: 'Mostly Cloudy',
1001: 'Cloudy',
2000: 'Fog',
2100: 'Light Fog',
4000: 'Drizzle',
4001: 'Rain',
4200: 'Light Rain',
4201: 'Heavy Rain',
5000: 'Snow',
5001: 'Flurries',
5100: 'Light Snow',
5101: 'Heavy Snow',
6000: 'Freezing Drizzle',
6001: 'Freezing Rain',
6200: 'Light Freezing Rain',
6201: 'Heavy Freezing Rain',
7000: 'Ice Pellets',
7101: 'Heavy Ice Pellets',
7102: 'Light Ice Pellets',
8000: 'Thunderstorm',
}
return {
'temperature': f'{values["temperatureApparent"]:0.0f}°C',
'description': code_lookup.get(values['weatherCode'], 'Unknown'),
}


async def main():
async with AsyncClient() as client:
logfire.instrument_httpx()
# create a free API key at https://www.tomorrow.io/weather-api/
weather_api_key = os.getenv('WEATHER_API_KEY')
# create a free API key at https://geocode.maps.co/
geo_api_key = os.getenv('GEO_API_KEY')
deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
result = await weather_agent.run('What is the weather like in London and in Wiltshire?', deps=deps)
debug(result)
print('Response:', result.response)


if __name__ == '__main__':
result = weather_agent.run_sync('What is the weather like in West London and in Wiltshire?')
debug(result)
asyncio.run(main())
13 changes: 8 additions & 5 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import asyncio
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, Generic, Literal, cast, overload
from typing import Any, Callable, Generic, Literal, cast, final, overload

import logfire_api
from pydantic import ValidationError
@@ -12,18 +12,19 @@
from . import _result, _retriever as _r, _system_prompt, _utils, messages as _messages, models, shared
from .shared import AgentDeps, ResultData

__all__ = ('Agent',)
__all__ = 'Agent', 'KnownModelName'
KnownModelName = Literal[
'openai:gpt-4o', 'openai:gpt-4-turbo', 'openai:gpt-4', 'openai:gpt-3.5-turbo', 'gemini-1.5-flash', 'gemini-1.5-pro'
]
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


@final
@dataclass(init=False)
class Agent(Generic[AgentDeps, ResultData]):
"""Main class for creating "agents" - a way to have a specific type of "conversation" with an LLM."""

# slots mostly for my sanity — knowing what attributes are available
# dataclass fields mostly for my sanity — knowing what attributes are available
model: models.Model | None
_result_schema: _result.ResultSchema[ResultData] | None
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
@@ -86,9 +87,10 @@ async def run(
The result of the run.
"""
if model is not None:
model_ = models.infer_model(model)
custom_model = model_ = models.infer_model(model)
elif self.model is not None:
model_ = self.model
custom_model = None
else:
raise shared.UserError('`model` must be set either when creating the agent or when calling it.')

@@ -111,7 +113,7 @@
cost = shared.Cost()

with _logfire.span(
'agent run {prompt=}', prompt=user_prompt, agent=self, model=model_, model_name=model_.name()
'agent run {prompt=}', prompt=user_prompt, agent=self, custom_model=custom_model, model_name=model_.name()
) as run_span:
try:
while True:
@@ -279,6 +281,7 @@ async def _handle_model_response(
for call in model_response.calls:
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
# should this be a retry error?
raise shared.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
coros.append(retriever.run(deps, call))
new_messages = await asyncio.gather(*coros)