diff --git a/docs/examples/chat-app.md b/docs/examples/chat-app.md index 97330987..3af5719a 100644 --- a/docs/examples/chat-app.md +++ b/docs/examples/chat-app.md @@ -8,10 +8,12 @@ Demonstrates: * reusing chat history * serializing messages +* streaming responses This demonstrates storing chat history between requests and using it to give the model context for new responses. -Most of the complex logic here is in `chat_app.html` which includes the page layout and JavaScript to handle the chat. +Most of the complex logic here is between `chat_app.py` which streams the response to the browser, +and `chat_app.ts` which renders messages in the browser. ## Running the Example @@ -27,10 +29,20 @@ TODO screenshot. ## Example Code +Python code that runs the chat app: + ```py title="chat_app.py" #! pydantic_ai_examples/chat_app.py ``` +Simple HTML page to render the app: + ```html title="chat_app.html" #! pydantic_ai_examples/chat_app.html ``` + +TypeScript to handle rendering the messages, to keep this simple (and at the risk of offending frontend developers) the typescript code is passed to the browser as plain text and transpiled in the browser. + +```ts title="chat_app.ts" +#! pydantic_ai_examples/chat_app.ts +``` diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index ee6de72b..73b327ed 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from dataclasses import dataclass +from datetime import datetime from typing import Generic, TypeVar, cast import logfire_api @@ -273,6 +274,10 @@ def cost(self) -> Cost: """ return self.cost_so_far + self._stream_response.cost() + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self._stream_response.timestamp() + async def validate_structured_result( self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False ) -> ResultData: diff --git a/pydantic_ai_examples/chat_app.html b/pydantic_ai_examples/chat_app.html index 20be83ae..da626cde 100644 --- a/pydantic_ai_examples/chat_app.html +++ b/pydantic_ai_examples/chat_app.html @@ -58,71 +58,24 @@

Chat App

+ diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py index 905065a6..3e825839 100644 --- a/pydantic_ai_examples/chat_app.py +++ b/pydantic_ai_examples/chat_app.py @@ -16,7 +16,12 @@ from pydantic import Field, TypeAdapter from pydantic_ai import Agent -from pydantic_ai.messages import Message, MessagesTypeAdapter, UserPrompt +from pydantic_ai.messages import ( + Message, + MessagesTypeAdapter, + ModelTextResponse, + UserPrompt, +) # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') @@ -32,6 +37,12 @@ async def index() -> HTMLResponse: return HTMLResponse((THIS_DIR / 'chat_app.html').read_bytes()) +@app.get('/chat_app.ts') +async def main_ts() -> Response: + """Get the raw typescript code, it's compiled in the browser, forgive me.""" + return Response((THIS_DIR / 'chat_app.ts').read_bytes(), media_type='text/plain') + + @app.get('/chat/') async def get_chat() -> Response: msgs = database.get_messages() @@ -49,12 +60,16 @@ async def stream_messages(): yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n' # get the chat history so far to pass as context to the agent messages = list(database.get_messages()) - response = await agent.run(prompt, message_history=messages) + # run the agent with the user prompt and the chat history + async with agent.run_stream(prompt, message_history=messages) as result: + async for text in result.stream(debounce_by=0.01): + # text here is a `str` and the frontend wants + # JSON encoded ModelTextResponse, so we create one + m = ModelTextResponse(content=text, timestamp=result.timestamp()) + yield MessageTypeAdapter.dump_json(m) + b'\n' + # add new messages (e.g. the user prompt and the agent response in this case) to the database - database.add_messages(response.new_messages_json()) - # stream the last message which will be the agent response, we can't just yield `new_messages_json()` - # since we already stream the user prompt - yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n' + database.add_messages(result.new_messages_json()) return StreamingResponse(stream_messages(), media_type='text/plain') diff --git a/pydantic_ai_examples/chat_app.ts b/pydantic_ai_examples/chat_app.ts new file mode 100644 index 00000000..86b60a64 --- /dev/null +++ b/pydantic_ai_examples/chat_app.ts @@ -0,0 +1,90 @@ +// BIG FAT WARNING: to avoid the complexity of npm, this typescript is compiled in the browser +// there's currently no static type checking + +import { marked } from 'https://cdnjs.cloudflare.com/ajax/libs/marked/15.0.0/lib/marked.esm.js' +const convElement = document.getElementById('conversation') + +const promptInput = document.getElementById('prompt-input') as HTMLInputElement +const spinner = document.getElementById('spinner') + +// stream the response and render messages as each chunk is received +// data is sent as newline-delimited JSON +async function onFetchResponse(response: Response): Promise { + let text = '' + let decoder = new TextDecoder() + if (response.ok) { + const reader = response.body.getReader() + while (true) { + const {done, value} = await reader.read() + if (done) { + break + } + text += decoder.decode(value) + addMessages(text) + spinner.classList.remove('active') + } + addMessages(text) + promptInput.disabled = false + promptInput.focus() + } else { + const text = await response.text() + console.error(`Unexpected response: ${response.status}`, {response, text}) + throw new Error(`Unexpected response: ${response.status}`) + } +} + +// The format of messages, this matches pydantic-ai both for brevity and understanding +// in production, you might not want to keep this format all the way to the frontend +interface Message { + role: string + content: string + timestamp: string +} + +// take raw response text and render messages into the `#conversation` element +// Message timestamp is assumed to be a unique identifier of a message, and is used to deduplicate +// hence you can send data about the same message multiple times, and it will be updated +// instead of creating a new message elements +function addMessages(responseText: string) { + const lines = responseText.split('\n') + const messages: Message[] = lines.filter(line => line.length > 1).map(j => JSON.parse(j)) + for (const message of messages) { + // we use the timestamp as a crude element id + const {timestamp, role, content} = message + const id = `msg-${timestamp}` + let msgDiv = document.getElementById(id) + if (!msgDiv) { + msgDiv = document.createElement('div') + msgDiv.id = id + msgDiv.title = `${role} at ${timestamp}` + msgDiv.classList.add('border-top', 'pt-2', role) + convElement.appendChild(msgDiv) + } + msgDiv.innerHTML = marked.parse(content) + } + window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }) +} + +function onError(error: any) { + console.error(error) + document.getElementById('error').classList.remove('d-none') + document.getElementById('spinner').classList.remove('active') +} + +async function onSubmit(e: SubmitEvent): Promise { + e.preventDefault() + spinner.classList.add('active') + const body = new FormData(e.target as HTMLFormElement) + + promptInput.value = '' + promptInput.disabled = true + + const response = await fetch('/chat/', {method: 'POST', body}) + await onFetchResponse(response) +} + +// call onSubmit when the form is submitted (e.g. user clicks the send button or hits Enter) +document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError)) + +// load messages on page load +fetch('/chat/').then(onFetchResponse).catch(onError) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index a9212b8a..cb8a3385 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -21,6 +21,7 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.result import Cost from tests.conftest import IsNow pytestmark = pytest.mark.anyio @@ -51,6 +52,8 @@ async def ret_a(x: str) -> str: response = await result.get_data() assert response == snapshot('{"ret_a":"a-apple"}') assert result.is_complete + assert result.cost() == snapshot(Cost()) + assert result.timestamp() == IsNow(tz=timezone.utc) assert result.all_messages() == snapshot( [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),