-
Notifications
You must be signed in to change notification settings - Fork 353
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Gradio Demo for weather agent example (#230)
Co-authored-by: David Montague <[email protected]> Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: sydney-runkle <[email protected]>
- Loading branch information
1 parent
db7b539
commit 01270cd
Showing
5 changed files
with
737 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import json | ||
import os | ||
|
||
from httpx import AsyncClient | ||
|
||
from pydantic_ai.messages import ToolCallPart, ToolReturnPart | ||
from pydantic_ai_examples.weather_agent import Deps, weather_agent | ||
|
||
try: | ||
import gradio as gr | ||
except ImportError as e: | ||
raise ImportError( | ||
'Please install gradio with `pip install gradio`. You must use python>=3.10.' | ||
) from e | ||
|
||
TOOL_TO_DISPLAY_NAME = {'get_lat_lng': 'Geocoding API', 'get_weather': 'Weather API'} | ||
|
||
client = AsyncClient() | ||
weather_api_key = os.getenv('WEATHER_API_KEY') | ||
# create a free API key at https://geocode.maps.co/ | ||
geo_api_key = os.getenv('GEO_API_KEY') | ||
deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key) | ||
|
||
|
||
async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list): | ||
chatbot.append({'role': 'user', 'content': prompt}) | ||
yield gr.Textbox(interactive=False, value=''), chatbot, gr.skip() | ||
async with weather_agent.run_stream( | ||
prompt, deps=deps, message_history=past_messages | ||
) as result: | ||
for message in result.new_messages(): | ||
for call in message.parts: | ||
if isinstance(call, ToolCallPart): | ||
call_args = ( | ||
call.args.args_json | ||
if hasattr(call.args, 'args_json') | ||
else json.dumps(call.args.args_dict) | ||
) | ||
gr_message = { | ||
'role': 'assistant', | ||
'content': 'Parameters: ' + call_args, | ||
'metadata': { | ||
'title': f'🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}', | ||
'id': call.tool_call_id, | ||
}, | ||
} | ||
chatbot.append(gr_message) | ||
if isinstance(call, ToolReturnPart): | ||
for gr_message in chatbot: | ||
if ( | ||
gr_message.get('metadata', {}).get('id', '') | ||
== call.tool_call_id | ||
): | ||
gr_message['content'] += ( | ||
f'\nOutput: {json.dumps(call.content)}' | ||
) | ||
yield gr.skip(), chatbot, gr.skip() | ||
chatbot.append({'role': 'assistant', 'content': ''}) | ||
async for message in result.stream_text(): | ||
chatbot[-1]['content'] = message | ||
yield gr.skip(), chatbot, gr.skip() | ||
past_messages = result.all_messages() | ||
|
||
yield gr.Textbox(interactive=True), gr.skip(), past_messages | ||
|
||
|
||
async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData): | ||
new_history = chatbot[: retry_data.index] | ||
previous_prompt = chatbot[retry_data.index]['content'] | ||
past_messages = past_messages[: retry_data.index] | ||
async for update in stream_from_agent(previous_prompt, new_history, past_messages): | ||
yield update | ||
|
||
|
||
def undo(chatbot, past_messages: list, undo_data: gr.UndoData): | ||
new_history = chatbot[: undo_data.index] | ||
past_messages = past_messages[: undo_data.index] | ||
return chatbot[undo_data.index]['content'], new_history, past_messages | ||
|
||
|
||
def select_data(message: gr.SelectData) -> str: | ||
return message.value['text'] | ||
|
||
|
||
with gr.Blocks() as demo: | ||
gr.HTML( | ||
""" | ||
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%"> | ||
<img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto"> | ||
<div> | ||
<h1 style="margin: 0 0 1rem 0">Weather Assistant</h1> | ||
<h3 style="margin: 0 0 0.5rem 0"> | ||
This assistant answer your weather questions. | ||
</h3> | ||
</div> | ||
</div> | ||
""" | ||
) | ||
past_messages = gr.State([]) | ||
chatbot = gr.Chatbot( | ||
label='Packing Assistant', | ||
type='messages', | ||
avatar_images=(None, 'https://ai.pydantic.dev/img/logo-white.svg'), | ||
examples=[ | ||
{'text': 'What is the weather like in Miami?'}, | ||
{'text': 'What is the weather like in London?'}, | ||
], | ||
) | ||
with gr.Row(): | ||
prompt = gr.Textbox( | ||
lines=1, | ||
show_label=False, | ||
placeholder='What is the weather like in New York City?', | ||
) | ||
generation = prompt.submit( | ||
stream_from_agent, | ||
inputs=[prompt, chatbot, past_messages], | ||
outputs=[prompt, chatbot, past_messages], | ||
) | ||
chatbot.example_select(select_data, None, [prompt]) | ||
chatbot.retry( | ||
handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages] | ||
) | ||
chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages]) | ||
|
||
|
||
if __name__ == '__main__': | ||
demo.launch() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.