Anthropic streaming support #652

Open · wants to merge 40 commits into base: dmontagu/refactor-streaming
Changes from all commits (40)
f7b4184
Refactor streaming
dmontagu Dec 18, 2024
eeaa952
WIP
dmontagu Dec 18, 2024
7c7977d
Merge main
dmontagu Jan 3, 2025
66dbc88
Merge branch 'main' into dmontagu/refactor-streaming
dmontagu Jan 3, 2025
2a01308
WIP
dmontagu Jan 3, 2025
807c2e5
Get streaming tests passing
dmontagu Jan 5, 2025
28b5d9d
Get gemini streaming working
dmontagu Jan 5, 2025
9c1a9b7
Further improve gemini streaming
dmontagu Jan 5, 2025
4845dfb
Fix more tests
dmontagu Jan 5, 2025
e0ea9c9
Fix more tests
dmontagu Jan 5, 2025
d8bed7a
Fix groq tests and examples
dmontagu Jan 6, 2025
4d75d6d
Use peekable stream to access timestamp in groq
dmontagu Jan 6, 2025
145a421
Merge branch 'main' into dmontagu/refactor-streaming
dmontagu Jan 6, 2025
7b959c4
Get openai tests passing
dmontagu Jan 6, 2025
c5590cc
Fix mistral tests
dmontagu Jan 6, 2025
ff36377
Remove the ability to yield None from StreamingResponse iterator
dmontagu Jan 6, 2025
5528ad1
Update example
dmontagu Jan 6, 2025
0ff2f87
Make PeekableAsyncStream work even if the stream can yield None
dmontagu Jan 6, 2025
f21aaf0
Remove PartStopEvent
dmontagu Jan 6, 2025
5eb5079
Fix for python 3.9
dmontagu Jan 7, 2025
c79ad0d
Merge main
dmontagu Jan 7, 2025
68efa65
Fix test
dmontagu Jan 7, 2025
f0a5f68
Add parts manager
dmontagu Jan 8, 2025
1ad34e1
Merge branch 'main' into dmontagu/refactor-streaming
dmontagu Jan 8, 2025
e024fe4
Fix syntax issue for 3.9
dmontagu Jan 8, 2025
8b712af
A bit more clean-up
dmontagu Jan 8, 2025
f647a38
Document the parts manager
dmontagu Jan 8, 2025
cfa72d0
Get rid of some of the excessive is_last tracking
dmontagu Jan 8, 2025
cd09247
Update a comment
dmontagu Jan 8, 2025
43e0b0f
Move much of the aiter implementation up to StreamedResponse
dmontagu Jan 8, 2025
94e8482
Move _parts_manager to StreamedResponse
dmontagu Jan 8, 2025
38c1e29
Move get() up to StreamedResponse
dmontagu Jan 8, 2025
9b28710
Remove the unused 'final' argument to get
dmontagu Jan 8, 2025
277ba84
Move usage() up to StreamedResponse
dmontagu Jan 8, 2025
7f71db3
Make MockAsyncStream generic
dmontagu Jan 8, 2025
626a69d
Add some tests of _parts_manager.py
dmontagu Jan 9, 2025
08a2f4f
Initial support for anthropic streaming
piercefreeman Jan 9, 2025
f14123c
Merge branch 'dmontagu/refactor-streaming' into feature/anthropic-str…
piercefreeman Jan 10, 2025
57e0ee9
MockAsyncStream requires global stream
piercefreeman Jan 10, 2025
256a3d9
Restore anthropic stream processor
piercefreeman Jan 10, 2025
96 changes: 78 additions & 18 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -1,21 +1,24 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterator
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import usage
from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ArgsDict,
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
ModelResponseStreamEvent,
RetryPromptPart,
SystemPromptPart,
TextPart,
@@ -38,11 +41,16 @@
from anthropic.types import (
Message as AnthropicMessage,
MessageParam,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
TextBlock,
TextBlockParam,
TextDelta,
ToolChoiceParam,
ToolParam,
ToolResultBlockParam,
@@ -231,22 +239,14 @@ def _process_response(response: AnthropicMessage) -> ModelResponse:

@staticmethod
async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
"""TODO: Process a streamed response, and prepare a streaming response to return."""
# We don't yet support streamed responses from Anthropic, so we raise an error here for now.
# Streamed responses will be supported in a future release.

raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')

# Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
# depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
# RawMessageStartEvent
# RawMessageDeltaEvent
# RawMessageStopEvent
# RawContentBlockStartEvent
# RawContentBlockDeltaEvent
    # RawContentBlockStopEvent
#
# We might refactor streaming internally before we implement this...
peekable_response = _utils.PeekableAsyncStream(response)
first_chunk = await peekable_response.peek()
if isinstance(first_chunk, _utils.Unset):
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

# Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
timestamp = datetime.now(tz=timezone.utc)
return AnthropicStreamedResponse(peekable_response, timestamp)

@staticmethod
def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
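
The rewritten `_process_streamed_response` above peeks at the first chunk so it can fail fast with `UnexpectedModelBehavior` on an empty stream before wrapping it. As a rough sketch of that pattern — not the actual `_utils.PeekableAsyncStream`, whose sentinel type and buffering details may differ — a peekable wrapper can buffer one item and re-yield it on iteration:

```python
# Hedged sketch of a peekable async stream; the real pydantic_ai._utils
# implementation may differ. A one-item list is used as the buffer so the
# wrapper still works when the underlying stream can legitimately yield None
# (cf. commit 0ff2f87, "Make PeekableAsyncStream work even if the stream can yield None").
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Generic, TypeVar

T = TypeVar('T')


class _Unset:
    """Sentinel meaning 'the stream ended before yielding anything'."""


class PeekableStreamSketch(Generic[T]):
    def __init__(self, source: AsyncIterator[T]):
        self._source = source
        self._buffer: list[T] = []  # holds at most one peeked item
        self._exhausted = False

    async def peek(self) -> T | _Unset:
        """Return the first item without consuming it, or _Unset if the stream is empty."""
        if not self._buffer and not self._exhausted:
            try:
                self._buffer.append(await self._source.__anext__())
            except StopAsyncIteration:
                self._exhausted = True
        return self._buffer[0] if self._buffer else _Unset()

    async def __aiter__(self) -> AsyncIterator[T]:
        first = await self.peek()
        if isinstance(first, _Unset):
            return
        self._buffer.clear()
        yield first  # re-yield the peeked chunk, then drain the rest of the source
        async for item in self._source:
            yield item
```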
@@ -342,3 +342,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
response_tokens=response_usage.output_tokens,
total_tokens=(request_tokens or 0) + response_usage.output_tokens,
)


@dataclass
class AnthropicStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for Anthropic models."""

_response: AsyncIterable[RawMessageStreamEvent]
_timestamp: datetime

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
current_block: TextBlock | ToolUseBlock | None = None
current_json: str = ''

async for event in self._response:
self._usage += _map_usage(event)

if isinstance(event, RawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, TextBlock) and current_block.text:
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
elif isinstance(current_block, ToolUseBlock):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=current_block.id,
tool_name=current_block.name,
args=cast(dict[str, Any], current_block.input),
tool_call_id=current_block.id,
)
if maybe_event is not None:
yield maybe_event

elif isinstance(event, RawContentBlockDeltaEvent):
if isinstance(event.delta, TextDelta):
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
elif (
current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
):
# Try to parse the JSON immediately, otherwise cache the value for later. This handles
# cases where the JSON is not currently valid but will be valid once we stream more tokens.
try:
parsed_args = json_loads(current_json + event.delta.partial_json)
current_json = ''
except JSONDecodeError:
current_json += event.delta.partial_json
continue

# For tool calls, we need to handle partial JSON updates
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=current_block.id,
tool_name='',
args=parsed_args,
tool_call_id=current_block.id,
)
if maybe_event is not None:
yield maybe_event

elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
current_block = None

def timestamp(self) -> datetime:
return self._timestamp
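
The `input_json_delta` branch above is the subtle part: tool arguments arrive as JSON fragments that may not parse on their own, so the iterator buffers them until `json_loads` succeeds. A self-contained sketch of that accumulation strategy (the `accumulate_tool_args` helper is illustrative only, not part of the codebase):

```python
# Minimal sketch of the partial-JSON accumulation in _get_event_iterator:
# append each fragment to a buffer and emit parsed args only once the buffer
# forms valid JSON. Illustrative helper, not a pydantic-ai API.
from __future__ import annotations

from json import JSONDecodeError, loads as json_loads


def accumulate_tool_args(fragments: list[str]) -> list[dict]:
    parsed: list[dict] = []
    buffer = ''
    for fragment in fragments:
        buffer += fragment
        try:
            parsed.append(json_loads(buffer))
            buffer = ''  # complete object parsed; reset for the next one
        except JSONDecodeError:
            continue  # not valid JSON yet; wait for more streamed tokens
    return parsed


# The two deltas streamed in the test below parse only once combined:
assert accumulate_tool_args(['{"second":', '"Two"}']) == [{'second': 'Two'}]
```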
171 changes: 166 additions & 5 deletions tests/models/test_anthropic.py
@@ -1,10 +1,11 @@
from __future__ import annotations as _annotations

import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import timezone
from functools import cached_property
from typing import Any, cast
from typing import Any, TypeVar, cast

import pytest
from inline_snapshot import snapshot
@@ -25,14 +26,24 @@
from ..conftest import IsNow, try_import

with try_import() as imports_successful:
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncStream
from anthropic.types import (
ContentBlock,
InputJSONDelta,
Message as AnthropicMessage,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
TextBlock,
ToolUseBlock,
Usage as AnthropicUsage,
)
from anthropic.types.raw_message_delta_event import Delta

from pydantic_ai.models.anthropic import AnthropicModel

Expand All @@ -41,28 +52,69 @@
pytest.mark.anyio,
]

# Type variable for generic AsyncStream
T = TypeVar('T')


def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
assert m.client.api_key == 'foobar'
assert m.name() == 'anthropic:claude-3-5-haiku-latest'


if imports_successful():

class MockAsyncStream(AsyncStream[T]):
"""Mock implementation of AsyncStream for testing."""

def __init__(self, events: list[list[T]]):
self.events = events
self.stream_index = 0

def __aiter__(self) -> AsyncIterator[T]:
if self.stream_index >= len(self.events):
raise StopAsyncIteration

async def iterator() -> AsyncIterator[T]:
current_stream = self.events[self.stream_index]
for event in current_stream:
yield event
self.stream_index += 1

        self._iterator = iterator()  # store the active generator so __anext__ can delegate to it
        return self._iterator

async def __anext__(self) -> T:
return await self._iterator.__anext__()

async def __aenter__(self) -> MockAsyncStream[T]:
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass
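
A quick illustration of the mock's contract: each `async for` pass replays the next recorded stream, which is how a single mock below feeds both model requests the agent makes (integer events are placeholders, not real `RawMessageStreamEvent`s; uses the `MockAsyncStream` defined above):

```python
# Illustrative only: MockAsyncStream yields one recorded stream per iteration pass.
import asyncio


async def _demo() -> None:
    mock = MockAsyncStream([[1, 2], [3]])
    assert [e async for e in mock] == [1, 2]  # consumed by the first model request
    assert [e async for e in mock] == [3]  # consumed by the second model request


asyncio.run(_demo())
```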


@dataclass
class MockAnthropic:
messages_: AnthropicMessage | list[AnthropicMessage] | None = None
messages_: AnthropicMessage | list[AnthropicMessage] | AsyncStream[RawMessageStreamEvent] | None = None
index = 0

@cached_property
def messages(self) -> Any:
return type('Messages', (), {'create': self.messages_create})

@classmethod
def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
def create_mock(
cls, messages_: AnthropicMessage | list[AnthropicMessage] | AsyncStream[RawMessageStreamEvent]
) -> AsyncAnthropic:
return cast(AsyncAnthropic, cls(messages_=messages_))

async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage:
async def messages_create(
self, *_args: Any, stream: bool = False, **_kwargs: Any
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
assert self.messages_ is not None, '`messages` must be provided'
if isinstance(self.messages_, AsyncStream):
assert stream, 'stream must be True when using AsyncStream'
return self.messages_
if isinstance(self.messages_, list):
response = self.messages_[self.index]
else:
@@ -241,3 +293,112 @@ async def get_location(loc_name: str) -> str:
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)


async def test_stream_structured(allow_model_requests: None):
"""Test streaming structured responses with Anthropic's API.

This test simulates how Anthropic streams tool calls:
1. Message start
2. Tool block start with initial data
3. Tool block delta with additional data
4. Tool block stop
5. Update usage
6. Message stop
"""
stream: list[RawMessageStreamEvent] = [
RawMessageStartEvent(
type='message_start',
message=AnthropicMessage(
id='msg_123',
model='claude-3-5-haiku-latest',
role='assistant',
type='message',
content=[],
stop_reason=None,
usage=AnthropicUsage(input_tokens=20, output_tokens=0),
),
),
# Start tool block with initial data
RawContentBlockStartEvent(
type='content_block_start',
index=0,
content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
),
# Add more data through an incomplete JSON delta
RawContentBlockDeltaEvent(
type='content_block_delta',
index=0,
delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
),
RawContentBlockDeltaEvent(
type='content_block_delta',
index=0,
delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
),
# Mark tool block as complete
RawContentBlockStopEvent(type='content_block_stop', index=0),
# Update the top-level message with usage
RawMessageDeltaEvent(
type='message_delta',
delta=Delta(
stop_reason='end_turn',
),
usage=MessageDeltaUsage(
output_tokens=5,
),
),
# Mark message as complete
RawMessageStopEvent(type='message_stop'),
]

done_stream: list[RawMessageStreamEvent] = [
RawMessageStartEvent(
type='message_start',
message=AnthropicMessage(
id='msg_123',
model='claude-3-5-haiku-latest',
role='assistant',
type='message',
content=[],
stop_reason=None,
usage=AnthropicUsage(input_tokens=0, output_tokens=0),
),
),
# Text block with final data
RawContentBlockStartEvent(
type='content_block_start',
index=0,
content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
),
RawContentBlockStopEvent(type='content_block_stop', index=0),
RawMessageStopEvent(type='message_stop'),
]

mock_client = MockAnthropic.create_mock(MockAsyncStream([stream, done_stream]))
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
agent = Agent(m)

tool_called = False

@agent.tool_plain
async def my_tool(first: str, second: str) -> int:
nonlocal tool_called
tool_called = True
return len(first) + len(second)

async with agent.run_stream('') as result:
assert not result.is_complete
chunks = [c async for c in result.stream(debounce_by=None)]

    # The tool invocation doesn't echo any content to the stream, so we only get the final payload
    # twice: once when the text block starts and once more when the stream ends.
assert chunks == snapshot(
[
'FINAL_PAYLOAD',
'FINAL_PAYLOAD',
]
)
assert result.is_complete
assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
assert tool_called
1 change: 1 addition & 0 deletions tests/models/test_gemini.py
@@ -742,6 +742,7 @@ async def bar(y: str) -> str:
assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])


# TODO: Is this test still necessary now that heterogeneous streaming is allowed?
async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
responses = [
gemini_response(_content_model_response(ModelResponse.from_text('Hello '))),
11 changes: 6 additions & 5 deletions tests/test_streaming.py
@@ -9,7 +9,7 @@
from inline_snapshot import snapshot
from pydantic import BaseModel

from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
from pydantic_ai import Agent, UnexpectedModelBehavior, capture_run_messages
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
@@ -124,10 +124,11 @@ async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> A

assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4]])

async with agent.run_stream('Hello') as result:
with pytest.raises(UserError, match=r'stream_text\(\) can only be used with text responses'):
async for _ in result.stream_text():
pass
# TODO: Can we remove the following check now?
# async with agent.run_stream('Hello') as result:
# with pytest.raises(UserError, match=r'stream_text\(\) can only be used with text responses'):
# async for _ in result.stream_text():
# pass


async def test_streamed_text_stream():
Expand Down