Skip to content

Commit

Permalink
add streaming to chat app (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 14, 2024
1 parent 48d2fd1 commit 6dc3e8d
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 69 deletions.
14 changes: 13 additions & 1 deletion docs/examples/chat-app.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Demonstrates:

* reusing chat history
* serializing messages
* streaming responses

This demonstrates storing chat history between requests and using it to give the model context for new responses.

Most of the complex logic here is in `chat_app.html` which includes the page layout and JavaScript to handle the chat.
Most of the complex logic here is between `chat_app.py` which streams the response to the browser,
and `chat_app.ts` which renders messages in the browser.

## Running the Example

Expand All @@ -27,10 +29,20 @@ TODO screenshot.

## Example Code

Python code that runs the chat app:

```py title="chat_app.py"
#! pydantic_ai_examples/chat_app.py
```

Simple HTML page to render the app:

```html title="chat_app.html"
#! pydantic_ai_examples/chat_app.html
```

TypeScript to handle rendering the messages, to keep this simple (and at the risk of offending frontend developers) the typescript code is passed to the browser as plain text and transpiled in the browser.

```ts title="chat_app.ts"
#! pydantic_ai_examples/chat_app.ts
```
5 changes: 5 additions & 0 deletions pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime
from typing import Generic, TypeVar, cast

import logfire_api
Expand Down Expand Up @@ -273,6 +274,10 @@ def cost(self) -> Cost:
"""
return self.cost_so_far + self._stream_response.cost()

def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
return self._stream_response.timestamp()

async def validate_structured_result(
self, message: messages.ModelStructuredResponse, *, allow_partial: bool = False
) -> ResultData:
Expand Down
77 changes: 15 additions & 62 deletions pydantic_ai_examples/chat_app.html
Original file line number Diff line number Diff line change
Expand Up @@ -58,71 +58,24 @@ <h1>Chat App</h1>
</main>
</body>
</html>
<script src="https://cdnjs.cloudflare.com/ajax/libs/typescript/5.6.3/typescript.min.js" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script type="module">
import { marked } from 'https://cdn.jsdelivr.net/npm/marked/lib/marked.esm.js';

function addMessages(lines) {
const messages = lines.filter(line => line.length > 1).map((line) => JSON.parse(line))
const parent = document.getElementById('conversation');
for (const message of messages) {
let msgDiv = document.createElement('div');
msgDiv.classList.add('border-top', 'pt-2', message.role);
msgDiv.innerHTML = marked.parse(message.content);
parent.appendChild(msgDiv);
}
// to let me write TypeScript, without adding the burden of npm we do a dirty, non-production-ready hack
// and transpile the TypeScript code in the browser
// this is (arguably) A neat demo trick, but not suitable for production!
async function loadTs() {
const response = await fetch('/chat_app.ts');
const tsCode = await response.text();
const jsCode = window.ts.transpile(tsCode, { target: "es2015" });
let script = document.createElement('script');
script.type = 'module';
script.text = jsCode;
document.body.appendChild(script);
}

function onError(error) {
console.error(error);
loadTs().catch((e) => {
console.error(e);
document.getElementById('error').classList.remove('d-none');
document.getElementById('spinner').classList.remove('active');
}

async function fetchResponse(response) {
let text = '';
if (response.ok) {
const reader = response.body.getReader();
while (true) {
const {done, value} = await reader.read();
if (done) {
break;
}
text += new TextDecoder().decode(value);
const lines = text.split('\n');
if (lines.length > 1) {
addMessages(lines.slice(0, -1));
text = lines[lines.length - 1];
}
}
addMessages(text.split('\n'));
let input = document.getElementById('prompt-input')
input.disabled = false;
input.focus();
} else {
const text = await response.text();
console.error(`Unexpected response: ${response.status}`, {response, text});
throw new Error(`Unexpected response: ${response.status}`);
}
}

async function onSubmit(e) {
e.preventDefault();
const spinner = document.getElementById('spinner');
spinner.classList.add('active');
const body = new FormData(e.target);

let input = document.getElementById('prompt-input')
input.value = '';
input.disabled = true;

const response = await fetch('/chat/', {method: 'POST', body});
await fetchResponse(response);
spinner.classList.remove('active');
}

// call onSubmit when form is submitted (e.g. user clicks the send button or hits Enter)
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError));

// load messages on page load
fetch('/chat/').then(fetchResponse).catch(onError);
});
</script>
27 changes: 21 additions & 6 deletions pydantic_ai_examples/chat_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from pydantic import Field, TypeAdapter

from pydantic_ai import Agent
from pydantic_ai.messages import Message, MessagesTypeAdapter, UserPrompt
from pydantic_ai.messages import (
Message,
MessagesTypeAdapter,
ModelTextResponse,
UserPrompt,
)

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
Expand All @@ -32,6 +37,12 @@ async def index() -> HTMLResponse:
return HTMLResponse((THIS_DIR / 'chat_app.html').read_bytes())


@app.get('/chat_app.ts')
async def main_ts() -> Response:
"""Get the raw typescript code, it's compiled in the browser, forgive me."""
return Response((THIS_DIR / 'chat_app.ts').read_bytes(), media_type='text/plain')


@app.get('/chat/')
async def get_chat() -> Response:
msgs = database.get_messages()
Expand All @@ -49,12 +60,16 @@ async def stream_messages():
yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n'
# get the chat history so far to pass as context to the agent
messages = list(database.get_messages())
response = await agent.run(prompt, message_history=messages)
# run the agent with the user prompt and the chat history
async with agent.run_stream(prompt, message_history=messages) as result:
async for text in result.stream(debounce_by=0.01):
# text here is a `str` and the frontend wants
# JSON encoded ModelTextResponse, so we create one
m = ModelTextResponse(content=text, timestamp=result.timestamp())
yield MessageTypeAdapter.dump_json(m) + b'\n'

# add new messages (e.g. the user prompt and the agent response in this case) to the database
database.add_messages(response.new_messages_json())
# stream the last message which will be the agent response, we can't just yield `new_messages_json()`
# since we already stream the user prompt
yield MessageTypeAdapter.dump_json(response.all_messages()[-1]) + b'\n'
database.add_messages(result.new_messages_json())

return StreamingResponse(stream_messages(), media_type='text/plain')

Expand Down
90 changes: 90 additions & 0 deletions pydantic_ai_examples/chat_app.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// BIG FAT WARNING: to avoid the complexity of npm, this typescript is compiled in the browser
// there's currently no static type checking

import { marked } from 'https://cdnjs.cloudflare.com/ajax/libs/marked/15.0.0/lib/marked.esm.js'
const convElement = document.getElementById('conversation')

const promptInput = document.getElementById('prompt-input') as HTMLInputElement
const spinner = document.getElementById('spinner')

// stream the response and render messages as each chunk is received
// data is sent as newline-delimited JSON
async function onFetchResponse(response: Response): Promise<void> {
let text = ''
let decoder = new TextDecoder()
if (response.ok) {
const reader = response.body.getReader()
while (true) {
const {done, value} = await reader.read()
if (done) {
break
}
text += decoder.decode(value)
addMessages(text)
spinner.classList.remove('active')
}
addMessages(text)
promptInput.disabled = false
promptInput.focus()
} else {
const text = await response.text()
console.error(`Unexpected response: ${response.status}`, {response, text})
throw new Error(`Unexpected response: ${response.status}`)
}
}

// The format of messages, this matches pydantic-ai both for brevity and understanding
// in production, you might not want to keep this format all the way to the frontend
interface Message {
role: string
content: string
timestamp: string
}

// take raw response text and render messages into the `#conversation` element
// Message timestamp is assumed to be a unique identifier of a message, and is used to deduplicate
// hence you can send data about the same message multiple times, and it will be updated
// instead of creating a new message elements
function addMessages(responseText: string) {
const lines = responseText.split('\n')
const messages: Message[] = lines.filter(line => line.length > 1).map(j => JSON.parse(j))
for (const message of messages) {
// we use the timestamp as a crude element id
const {timestamp, role, content} = message
const id = `msg-${timestamp}`
let msgDiv = document.getElementById(id)
if (!msgDiv) {
msgDiv = document.createElement('div')
msgDiv.id = id
msgDiv.title = `${role} at ${timestamp}`
msgDiv.classList.add('border-top', 'pt-2', role)
convElement.appendChild(msgDiv)
}
msgDiv.innerHTML = marked.parse(content)
}
window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' })
}

function onError(error: any) {
console.error(error)
document.getElementById('error').classList.remove('d-none')
document.getElementById('spinner').classList.remove('active')
}

async function onSubmit(e: SubmitEvent): Promise<void> {
e.preventDefault()
spinner.classList.add('active')
const body = new FormData(e.target as HTMLFormElement)

promptInput.value = ''
promptInput.disabled = true

const response = await fetch('/chat/', {method: 'POST', body})
await onFetchResponse(response)
}

// call onSubmit when the form is submitted (e.g. user clicks the send button or hits Enter)
document.querySelector('form').addEventListener('submit', (e) => onSubmit(e).catch(onError))

// load messages on page load
fetch('/chat/').then(onFetchResponse).catch(onError)
3 changes: 3 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import Cost
from tests.conftest import IsNow

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -51,6 +52,8 @@ async def ret_a(x: str) -> str:
response = await result.get_data()
assert response == snapshot('{"ret_a":"a-apple"}')
assert result.is_complete
assert result.cost() == snapshot(Cost())
assert result.timestamp() == IsNow(tz=timezone.utc)
assert result.all_messages() == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
Expand Down

0 comments on commit 6dc3e8d

Please sign in to comment.