Skip to content

Commit

Permalink
Add Gradio Demo for weather agent example (#230)
Browse files Browse the repository at this point in the history
Co-authored-by: David Montague <[email protected]>
Co-authored-by: Sydney Runkle <[email protected]>
Co-authored-by: sydney-runkle <[email protected]>
  • Loading branch information
4 people authored Jan 8, 2025
1 parent db7b539 commit 01270cd
Show file tree
Hide file tree
Showing 5 changed files with 737 additions and 8 deletions.
22 changes: 22 additions & 0 deletions docs/examples/weather-agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Demonstrates:
* [tools](../tools.md)
* [agent dependencies](../dependencies.md)
* [streaming text responses](../results.md#streaming-text)
* Building a [Gradio](https://www.gradio.app/) UI for the agent

In this case the idea is a "weather" agent — the user can ask for the weather in multiple locations,
the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use
Expand All @@ -28,3 +29,24 @@ python/uv-run -m pydantic_ai_examples.weather_agent
```python {title="pydantic_ai_examples/weather_agent.py"}
#! examples/pydantic_ai_examples/weather_agent.py
```

## Running the UI

You can build multi-turn chat applications for your agent with [Gradio](https://www.gradio.app/), a framework for building AI web applications entirely in python. Gradio comes with built-in chat components and agent support so the entire UI will be implemented in a single python file!

Here's what the UI looks like for the weather agent:

{{ video('c549d8d8827ded15f326f998e428e6c3', 25) }}

Note, to run the UI, you'll need Python 3.10+.

```bash
pip install gradio>=5.9.0
python/uv-run -m pydantic_ai_examples.weather_agent_gradio
```

## UI Code

```python {title="pydantic_ai_examples/weather_agent_gradio.py"}
#! pydantic_ai_examples/weather_agent_gradio.py
```
130 changes: 130 additions & 0 deletions examples/pydantic_ai_examples/weather_agent_gradio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations as _annotations

import json
import os

from httpx import AsyncClient

from pydantic_ai.messages import ToolCallPart, ToolReturnPart
from pydantic_ai_examples.weather_agent import Deps, weather_agent

try:
import gradio as gr
except ImportError as e:
raise ImportError(
'Please install gradio with `pip install gradio`. You must use python>=3.10.'
) from e

TOOL_TO_DISPLAY_NAME = {'get_lat_lng': 'Geocoding API', 'get_weather': 'Weather API'}

client = AsyncClient()
weather_api_key = os.getenv('WEATHER_API_KEY')
# create a free API key at https://geocode.maps.co/
geo_api_key = os.getenv('GEO_API_KEY')
deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)


async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
chatbot.append({'role': 'user', 'content': prompt})
yield gr.Textbox(interactive=False, value=''), chatbot, gr.skip()
async with weather_agent.run_stream(
prompt, deps=deps, message_history=past_messages
) as result:
for message in result.new_messages():
for call in message.parts:
if isinstance(call, ToolCallPart):
call_args = (
call.args.args_json
if hasattr(call.args, 'args_json')
else json.dumps(call.args.args_dict)
)
gr_message = {
'role': 'assistant',
'content': 'Parameters: ' + call_args,
'metadata': {
'title': f'🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}',
'id': call.tool_call_id,
},
}
chatbot.append(gr_message)
if isinstance(call, ToolReturnPart):
for gr_message in chatbot:
if (
gr_message.get('metadata', {}).get('id', '')
== call.tool_call_id
):
gr_message['content'] += (
f'\nOutput: {json.dumps(call.content)}'
)
yield gr.skip(), chatbot, gr.skip()
chatbot.append({'role': 'assistant', 'content': ''})
async for message in result.stream_text():
chatbot[-1]['content'] = message
yield gr.skip(), chatbot, gr.skip()
past_messages = result.all_messages()

yield gr.Textbox(interactive=True), gr.skip(), past_messages


async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
new_history = chatbot[: retry_data.index]
previous_prompt = chatbot[retry_data.index]['content']
past_messages = past_messages[: retry_data.index]
async for update in stream_from_agent(previous_prompt, new_history, past_messages):
yield update


def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
new_history = chatbot[: undo_data.index]
past_messages = past_messages[: undo_data.index]
return chatbot[undo_data.index]['content'], new_history, past_messages


def select_data(message: gr.SelectData) -> str:
return message.value['text']


with gr.Blocks() as demo:
gr.HTML(
"""
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
<img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
<div>
<h1 style="margin: 0 0 1rem 0">Weather Assistant</h1>
<h3 style="margin: 0 0 0.5rem 0">
This assistant answer your weather questions.
</h3>
</div>
</div>
"""
)
past_messages = gr.State([])
chatbot = gr.Chatbot(
label='Packing Assistant',
type='messages',
avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'),
examples=[
{'text': 'What is the weather like in Miami?'},
{'text': 'What is the weather like in London?'},
],
)
with gr.Row():
prompt = gr.Textbox(
lines=1,
show_label=False,
placeholder='What is the weather like in New York City?',
)
generation = prompt.submit(
stream_from_agent,
inputs=[prompt, chatbot, past_messages],
outputs=[prompt, chatbot, past_messages],
)
chatbot.example_select(select_data, None, [prompt])
chatbot.retry(
handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages]
)
chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages])


if __name__ == '__main__':
demo.launch()
1 change: 1 addition & 0 deletions examples/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"rich>=13.9.2",
"uvicorn>=0.32.0",
"devtools>=0.12.2",
"gradio>=5.9.0; python_version>'3.9'",
]

[tool.hatch.build.targets.wheel]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ venvPath = ".venv"
executionEnvironments = [
{ root = "tests", reportUnusedFunction = false },
]
exclude = ["examples/pydantic_ai_examples/weather_agent_gradio.py"]

[tool.pytest.ini_options]
testpaths = "tests"
Expand Down
Loading

0 comments on commit 01270cd

Please sign in to comment.