From f7b4184bdf29621692b92b9cd69c3e194dc32b8a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 18 Dec 2024 07:14:39 -0700 Subject: [PATCH 01/34] Refactor streaming --- pydantic_ai_slim/pydantic_ai/messages.py | 58 +++++- .../pydantic_ai/models/__init__.py | 6 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 189 +++++++++++------- pydantic_ai_slim/pydantic_ai/result.py | 2 + 4 files changed, 172 insertions(+), 83 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 481cb4d5..6fd88ecf 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime from typing import Annotated, Any, Literal, Union @@ -233,10 +233,58 @@ def from_tool_call(cls, tool_call: ToolCallPart) -> Self: return cls([tool_call]) -ModelMessage = Union[ModelRequest, ModelResponse] +ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')] """Any message sent to or returned by a model.""" -ModelMessagesTypeAdapter = pydantic.TypeAdapter( - list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True) -) +ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True)) """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages.""" + + +@dataclass +class TextPartDelta: + content_delta: str + part_delta_kind: Literal['text'] = 'text' + + def apply(self, part: TextPart) -> TextPart: + return replace(part, content=part.content + self.content_delta) + + +@dataclass +class ToolCallPartDelta: + args_json_delta: str + part_delta_kind: Literal['tool_call'] = 'tool_call' + + def apply(self, part: ToolCallPart) -> ToolCallPart: + assert isinstance(part.args, ArgsJson), 'Cannot apply deltas to non-JSON tool arguments' + updated_json = part.args.args_json + self.args_json_delta + return replace(part, args=ArgsJson(updated_json)) + + +ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')] + + +@dataclass +class PartStartEvent: + """If multiple PartStartEvents are received with the same index, the new one should fully replace the old one.""" + + index: int + part: ModelResponsePart + event_kind: Literal['part_start'] = 'part_start' + + +@dataclass +class PartDeltaEvent: + index: int + delta: ModelResponsePartDelta + event_kind: Literal['part_delta'] = 'part_delta' + + +@dataclass +class PartStopEvent: + index: int + event_kind: Literal['part_stop'] = 'part_stop' + + +ModelResponseStreamEvent = Annotated[ + Union[PartStartEvent, PartDeltaEvent, PartStopEvent], pydantic.Discriminator('event_kind') +] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 0291e6bc..8950a679 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -16,7 +16,7 @@ import httpx from ..exceptions import UserError -from ..messages import ModelMessage, ModelResponse +from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent from ..settings import ModelSettings if TYPE_CHECKING: @@ -180,7 +180,7 @@ def timestamp(self) -> datetime: class StreamStructuredResponse(ABC): 
"""Streamed response from an LLM when calling a tool.""" - def __aiter__(self) -> AsyncIterator[None]: + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: # TODO: Should we drop the None? I think so """Stream the response as an async iterable, building up the tool call as it goes. This is an async iterator that yields `None` to avoid doing the work of building the final tool call when @@ -189,7 +189,7 @@ def __aiter__(self) -> AsyncIterator[None]: return self @abstractmethod - async def __anext__(self) -> None: + async def __anext__(self) -> ModelResponseStreamEvent | None: # TODO: Should we drop the None? I think so """Process the next chunk of the response, see above for why this returns `None`.""" raise NotImplementedError() diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 95759598..f852bef3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Iterable, Iterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -10,18 +10,23 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import UnexpectedModelBehavior, _utils, result -from .._utils import guard_tool_call_id as _guard_tool_call_id +from .. import _utils, result +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( ArgsJson, ModelMessage, ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, + ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -30,7 +35,6 @@ from ..tools import ToolDefinition from . 
import ( AgentModel, - EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse, @@ -152,12 +156,20 @@ async def request( return self._process_response(response), _map_cost(response) @asynccontextmanager - async def request_stream( + async def request_stream_text( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamTextResponse]: response = await self._completions_create(messages, True, model_settings) async with response: - yield await self._process_streamed_response(response) + yield OpenAIStreamTextResponse(response) + + @asynccontextmanager + async def request_stream_structured( + self, messages: list[ModelMessage], model_settings: ModelSettings | None + ) -> AsyncIterator[StreamStructuredResponse]: + response = await self._completions_create(messages, True, model_settings) + async with response: + yield OpenAIStreamStructuredResponse(response) @overload async def _completions_create( @@ -214,35 +226,6 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id)) return ModelResponse(items, timestamp=timestamp) - @staticmethod - async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse: - """Process a streamed response, and prepare a streaming response to return.""" - timestamp: datetime | None = None - start_cost = Cost() - # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content` - while True: - try: - chunk = await response.__anext__() - except StopAsyncIteration as e: - raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e - - timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) - start_cost += _map_cost(chunk) - - if chunk.choices: - delta = chunk.choices[0].delta - - if delta.content is not None: - return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost) - elif delta.tool_calls is not None: - return OpenAIStreamStructuredResponse( - response, - {c.index: c for c in delta.tool_calls}, - timestamp, - start_cost, - ) - # else continue until we get either delta.content or delta.tool_calls - @classmethod def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" @@ -299,19 +282,16 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio class OpenAIStreamTextResponse(StreamTextResponse): """Implementation of `StreamTextResponse` for OpenAI models.""" - _first: str | None _response: AsyncStream[ChatCompletionChunk] - _timestamp: datetime - _cost: result.Cost + _timestamp: datetime | None = field(default=None, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: - if self._first is not None: - self._buffer.append(self._first) - self._first = None - return None - chunk = await self._response.__anext__() + if self._timestamp is None: + self._timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) + self._cost += _map_cost(chunk) try: choice = chunk.choices[0] @@ -332,7 +312,11 @@ def cost(self) -> Cost: return self._cost def timestamp(self) -> datetime: - return self._timestamp + # TODO: the following seems problematic + 
return self._timestamp or datetime.now(tz=timezone.utc) + + +_CONTENT_INDEX = 0 @dataclass @@ -340,47 +324,102 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse): """Implementation of `StreamStructuredResponse` for OpenAI models.""" _response: AsyncStream[ChatCompletionChunk] - _delta_tool_calls: dict[int, ChoiceDeltaToolCall] - _timestamp: datetime - _cost: result.Cost - async def __anext__(self) -> None: - chunk = await self._response.__anext__() - self._cost += _map_cost(chunk) - try: - choice = chunk.choices[0] - except IndexError: - raise StopAsyncIteration() + _timestamp: datetime | None = field(default=None, init=False) + _cost: result.Cost = field(default_factory=result.Cost, init=False) + _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) - if choice.finish_reason is not None: - raise StopAsyncIteration() + _content_part: TextPart | None = field(default=None, init=False) + _tool_call_parts: dict[int, ToolCallPart] = field(default_factory=dict, init=False) + _async_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) - assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}' + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + if self._async_iterator is None: + self._async_iterator = self._get_async_iterator() + return self._async_iterator - for new in choice.delta.tool_calls or []: - if current := self._delta_tool_calls.get(new.index): - if current.function is None: - current.function = new.function - elif new.function is not None: - current.function.name = _utils.add_optional(current.function.name, new.function.name) - current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments) - else: - self._delta_tool_calls[new.index] = new + async def __anext__(self) -> ModelResponseStreamEvent | None: + if self._async_iterator is None: + self._async_iterator = self._get_async_iterator() + return await self._async_iterator.__anext__() - def get(self, *, final: bool = False) -> ModelResponse: - items: list[ModelResponsePart] = [] - for c in self._delta_tool_calls.values(): - if f := c.function: - if f.name is not None and f.arguments is not None: - items.append(ToolCallPart.from_json(f.name, f.arguments, c.id)) + async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + async for chunk in self._response: + if self._timestamp is None: + self._timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) + + self._cost += _map_cost(chunk) + try: + choice = chunk.choices[0] + except IndexError: + return # raising StopAsyncIteration inside an async generator becomes a RuntimeError (PEP 525) + + for e in self._update_parts_for_content_delta(choice.delta.content): + yield e + + for new in choice.delta.tool_calls or []: + if current := self._delta_tool_calls.get(new.index): + if new.function is not None: + if current.function is None: + replace_existing_part = True + current.function = new.function + else: + replace_existing_part = bool(new.function.name) + current.function.name = _utils.add_optional(current.function.name, new.function.name) + current.function.arguments = _utils.add_optional( + current.function.arguments, new.function.arguments + ) + for e in self._update_parts_for_tool_call_delta(current, replace_existing_part): + yield e + else: + self._delta_tool_calls[new.index] = new + for e in self._update_parts_for_tool_call_delta(new, True): + yield e - return ModelResponse(items, timestamp=self._timestamp) + if 
choice.finish_reason is not None: + return + + def get(self, *, final: bool = False) -> ModelResponse: + items: list[ModelResponsePart] = [self._content_part] if self._content_part is not None else [] + items.extend([self._tool_call_parts[k] for k in sorted(self._tool_call_parts.keys())]) + return ModelResponse(items, timestamp=self.timestamp()) def cost(self) -> Cost: return self._cost def timestamp(self) -> datetime: - return self._timestamp + return self._timestamp or _now_utc() + + def _update_parts_for_content_delta(self, choice_delta_content: str | None) -> Iterator[ModelResponseStreamEvent]: + if choice_delta_content is None: + return + + existing_content = self._content_part + if existing_content is None: + part = TextPart(content=choice_delta_content) + self._content_part = part + yield PartStartEvent(index=_CONTENT_INDEX, part=part) + else: + delta = TextPartDelta(content_delta=choice_delta_content) + self._content_part = delta.apply(existing_content) + yield PartDeltaEvent(index=_CONTENT_INDEX, delta=delta) + + def _update_parts_for_tool_call_delta( + self, tc: ChoiceDeltaToolCall, replace: bool ) -> Iterator[ModelResponseStreamEvent]: + if tc.function is None: + return + + if replace: + assert tc.function.name is not None + new_part = ToolCallPart.from_json(tc.function.name, tc.function.arguments or '', tc.id) + self._tool_call_parts[tc.index] = new_part + yield PartStartEvent(index=tc.index, part=new_part) + else: + assert (existing_part := self._tool_call_parts.get(tc.index)) is not None + delta = ToolCallPartDelta(args_json_delta=tc.function.arguments or '') + self._tool_call_parts[tc.index] = delta.apply(existing_part) + yield PartDeltaEvent(index=tc.index, delta=delta) def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 281b8f5f..d15961d7 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -149,6 +149,8 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu Returns: An async iterable of the response data. 
""" + # TODO: Drop the following isinstance and use stream_structured even for text (and make that work) + # Rename stream_structured to stream_responses or similar (the idea is just that you aren't getting diffs) if isinstance(self._stream_response, models.StreamTextResponse): async for text in self.stream_text(debounce_by=debounce_by): yield cast(ResultData, text) From eeaa952390bdca3cff990de92573a035b65124fc Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:02:16 -0700 Subject: [PATCH 02/34] WIP --- docs/api/models/base.md | 2 +- docs/models.md | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 101 ++++++-------- pydantic_ai_slim/pydantic_ai/messages.py | 1 + .../pydantic_ai/models/__init__.py | 51 +------ .../pydantic_ai/models/anthropic.py | 8 +- .../pydantic_ai/models/function.py | 50 +------ pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 13 +- .../pydantic_ai/models/mistral.py | 15 +-- pydantic_ai_slim/pydantic_ai/models/openai.py | 11 +- pydantic_ai_slim/pydantic_ai/models/test.py | 9 +- pydantic_ai_slim/pydantic_ai/result.py | 126 ++++++++---------- tests/models/test_mistral.py | 4 +- 14 files changed, 146 insertions(+), 260 deletions(-) diff --git a/docs/api/models/base.md b/docs/api/models/base.md index ea2ed5f0..9f847e9a 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -8,7 +8,7 @@ - AgentModel - AbstractToolDefinition - StreamTextResponse - - StreamStructuredResponse + - StreamedResponse - ALLOW_MODEL_REQUESTS - check_allow_model_requests - override_allow_model_requests diff --git a/docs/models.md b/docs/models.md index af271181..95e1935a 100644 --- a/docs/models.md +++ b/docs/models.md @@ -456,7 +456,7 @@ This in turn will require you to implement the following other abstract base cla * [`AgentModel`][pydantic_ai.models.AgentModel] * [`StreamTextResponse`][pydantic_ai.models.StreamTextResponse] -* [`StreamStructuredResponse`][pydantic_ai.models.StreamStructuredResponse] +* [`StreamedResponse`][pydantic_ai.models.StreamedResponse] The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py). diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 542086f2..3e1cffd9 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -21,6 +21,7 @@ models, result, ) +from .messages import PartStartEvent, TextPart, ToolCallPart from .result import ResultData from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -968,74 +969,56 @@ async def _process_function_tools( async def _handle_streamed_model_response( self, - model_response: models.EitherStreamedResponse, + streamed_response: models.StreamedResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage], - ) -> ( - _MarkFinalResult[models.EitherStreamedResponse] - | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]] - ): + ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]: """Process a streamed response from the model. Returns: Either a final result or a tuple of the model response and the tool responses for the next request. If a final result is returned, the conversation should end. 
""" - if isinstance(model_response, models.StreamTextResponse): - # plain string response - if self._allow_text_result: - return _MarkFinalResult(model_response, None) - else: - self._incr_result_retry() - response = _messages.RetryPromptPart( - content='Plain text responses are not permitted, please call one of the functions instead.', - ) - # stream the response, so cost is correct - async for _ in model_response: - pass - - text = ''.join(model_response.get(final=True)) - return _messages.ModelResponse([_messages.TextPart(text)]), [response] - elif isinstance(model_response, models.StreamStructuredResponse): - if self._result_schema is not None: - # if there's a result schema, iterate over the stream until we find at least one tool - # NOTE: this means we ignore any other tools called here - structured_msg = model_response.get() - while not structured_msg.parts: - try: - await model_response.__anext__() - except StopAsyncIteration: - break - structured_msg = model_response.get() - - if match := self._result_schema.find_tool(structured_msg.parts): - call, _ = match - return _MarkFinalResult(model_response, call.tool_name) - - # the model is calling a tool function, consume the response to get the next message - async for _ in model_response: - pass - model_response_msg = model_response.get() - if not model_response_msg.parts: - raise exceptions.UnexpectedModelBehavior('Received empty tool call message') - - # we now run all tool functions in parallel - tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] - parts: list[_messages.ModelRequestPart] = [] - for item in model_response_msg.parts: - if isinstance(item, _messages.ToolCallPart): - call = item - if tool := self._function_tools.get(call.tool_name): - tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name)) - else: - parts.append(self._unknown_tool(call.tool_name)) + received_text = False + + async for maybe_part_event in streamed_response: + if maybe_part_event is None: + continue + if isinstance(maybe_part_event, PartStartEvent): + new_part = maybe_part_event.part + if isinstance(new_part, TextPart): + received_text = True + if self._allow_text_result: + return _MarkFinalResult(streamed_response, None) + elif isinstance(new_part, ToolCallPart): + if self._result_schema is not None and (match := self._result_schema.find_tool([new_part])): + call, _ = match + return _MarkFinalResult(streamed_response, call.tool_name) + else: + assert_never(new_part) - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) - parts.extend(task_results) - return model_response_msg, parts - else: - assert_never(model_response) + tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] + parts: list[_messages.ModelRequestPart] = [] + model_response = streamed_response.get() + for p in model_response.parts: + if isinstance(p, ToolCallPart): + if tool := self._function_tools.get(p.tool_name): + tasks.append(asyncio.create_task(tool.run(deps, p, conv_messages), name=p.tool_name)) + else: + parts.append(self._unknown_tool(p.tool_name)) + + if received_text and not tasks and not parts: + # Can only get here if self._allow_text_result is False + self._incr_result_retry() + model_response = _messages.RetryPromptPart( + content='Plain text responses are not permitted, please call one of the functions instead.', + ) + return streamed_response.get(), [model_response] + + with _logfire.span('running {tools=}', 
tools=[t.get_name() for t in tasks]): + task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) + parts.extend(task_results) + return model_response, parts async def _validate_result( self, diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 6fd88ecf..5dd21482 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -282,6 +282,7 @@ class PartDeltaEvent: @dataclass class PartStopEvent: index: int + part: ModelResponsePart event_kind: Literal['part_stop'] = 'part_stop' diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 8950a679..17796443 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -7,11 +7,11 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterable, Iterator +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager from datetime import datetime from functools import cache -from typing import TYPE_CHECKING, Literal, Union +from typing import TYPE_CHECKING, Literal import httpx @@ -129,7 +129,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}') # yield is required to make this a generator for type checking @@ -137,47 +137,7 @@ async def request_stream( yield # pragma: no cover -class StreamTextResponse(ABC): - """Streamed response from an LLM when returning text.""" - - def __aiter__(self) -> AsyncIterator[None]: - """Stream the response as an async iterable, building up the text as it goes. - - This is an async iterator that yields `None` to avoid doing the work of validating the input and - extracting the text field when it will often be thrown away. - """ - return self - - @abstractmethod - async def __anext__(self) -> None: - """Process the next chunk of the response, see above for why this returns `None`.""" - raise NotImplementedError() - - @abstractmethod - def get(self, *, final: bool = False) -> Iterable[str]: - """Returns an iterable of text since the last call to `get()` — e.g. the text delta. - - Args: - final: If True, this is the final call, after iteration is complete, the response should be fully validated - and all text extracted. - """ - raise NotImplementedError() - - @abstractmethod - def cost(self) -> Cost: - """Return the cost of the request. - - NOTE: this won't return the ful cost until the stream is finished. - """ - raise NotImplementedError() - - @abstractmethod - def timestamp(self) -> datetime: - """Get the timestamp of the response.""" - raise NotImplementedError() - - -class StreamStructuredResponse(ABC): +class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: # TODO: Should we drop the None? I think so @@ -218,9 +178,6 @@ def timestamp(self) -> datetime: raise NotImplementedError() -EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse] - - ALLOW_MODEL_REQUESTS = True """Whether to allow requests to models. 
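
The renamed `StreamedResponse` ABC, combined with the `PartStartEvent`/`PartDeltaEvent` types introduced in patch 01, implies a single folding pattern for consumers: keep parts keyed by event index, store the part from a start event (a repeated start with the same index fully replaces the old part), and apply each delta to the stored part. A minimal sketch of that pattern, assuming only the types added in these patches — `collect_parts` itself is a hypothetical helper, not part of the patch, and `PartStopEvent` is ignored for brevity:

from pydantic_ai.messages import ModelResponsePart, PartDeltaEvent, PartStartEvent
from pydantic_ai.models import StreamedResponse

async def collect_parts(streamed_response: StreamedResponse) -> list[ModelResponsePart]:
    # Fold the event stream back into concrete response parts.
    parts: dict[int, ModelResponsePart] = {}
    async for event in streamed_response:
        if event is None:
            # raw chunk carried no usable information yet
            continue
        if isinstance(event, PartStartEvent):
            # a new start event at an existing index fully replaces that part
            parts[event.index] = event.part
        elif isinstance(event, PartDeltaEvent):
            # assumes the stored part has the kind matching the delta
            parts[event.index] = event.delta.apply(parts[event.index])
    return [parts[index] for index in sorted(parts)]

This is the same bookkeeping the concrete implementations below (`FunctionStreamedResponse`, `TestStreamedResponse`, `OpenAIStreamedResponse`) perform internally so that `get()` can assemble a `ModelResponse` at any point during the stream.
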
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 4c4c4384..670ca878 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -27,8 +27,8 @@ from ..tools import ToolDefinition from . import ( AgentModel, - EitherStreamedResponse, Model, + StreamedResponse, cached_async_http_client, check_allow_model_requests, ) @@ -165,7 +165,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: response = await self._messages_create(messages, True, model_settings) async with response: yield await self._process_streamed_response(response) @@ -230,14 +230,14 @@ def _process_response(response: AnthropicMessage) -> ModelResponse: return ModelResponse(items) @staticmethod - async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse: + async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse: """TODO: Process a streamed response, and prepare a streaming response to return.""" # We don't yet support streamed responses from Anthropic, so we raise an error here for now. # Streamed responses will be supported in a future release. raise RuntimeError('Streamed responses are not yet supported for Anthropic models.') - # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse + # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following: # RawMessageStartEvent # RawMessageDeltaEvent diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 007538f4..cca640df 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -28,7 +28,7 @@ ) from ..settings import ModelSettings from ..tools import ToolDefinition -from . import AgentModel, EitherStreamedResponse, Model, StreamStructuredResponse, StreamTextResponse +from . 
import AgentModel, Model, StreamedResponse @dataclass(init=False) @@ -160,57 +160,19 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: assert ( self.stream_function is not None ), 'FunctionModel must receive a `stream_function` to support streamed requests' response_stream = self.stream_function(messages, self.agent_info) - try: - first = await response_stream.__anext__() - except StopAsyncIteration as e: - raise ValueError('Stream function must return at least one item') from e - - if isinstance(first, str): - text_stream = cast(AsyncIterator[str], response_stream) - yield FunctionStreamTextResponse(first, text_stream) - else: - structured_stream = cast(AsyncIterator[DeltaToolCalls], response_stream) - yield FunctionStreamStructuredResponse(first, structured_stream) - - -@dataclass -class FunctionStreamTextResponse(StreamTextResponse): - """Implementation of `StreamTextResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - - _next: str | None - _iter: AsyncIterator[str] - _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _buffer: list[str] = field(default_factory=list, init=False) - - async def __anext__(self) -> None: - if self._next is not None: - self._buffer.append(self._next) - self._next = None - else: - self._buffer.append(await self._iter.__anext__()) - - def get(self, *, final: bool = False) -> Iterable[str]: - yield from self._buffer - self._buffer.clear() - - def cost(self) -> result.Cost: - return result.Cost() - - def timestamp(self) -> datetime: - return self._timestamp + yield FunctionStreamedResponse(response_stream) @dataclass -class FunctionStreamStructuredResponse(StreamStructuredResponse): - """Implementation of `StreamStructuredResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" +class FunctionStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" - _next: DeltaToolCalls | None - _iter: AsyncIterator[DeltaToolCalls] + _iter: AsyncIterator[str | DeltaToolCalls] _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict) _timestamp: datetime = field(default_factory=_utils.now_utc) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 72906428..3877ccc4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -32,9 +32,8 @@ from ..tools import ToolDefinition from . 
import ( AgentModel, - EitherStreamedResponse, Model, - StreamStructuredResponse, + StreamedResponse, StreamTextResponse, cached_async_http_client, check_allow_model_requests, @@ -180,7 +179,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: async with self._make_request(messages, True, model_settings) as http_response: yield await self._process_streamed_response(http_response) @@ -239,7 +238,7 @@ def _process_response(response: _GeminiResponse) -> ModelResponse: return _process_response_from_parts(parts) @staticmethod - async def _process_streamed_response(http_response: HTTPResponse) -> EitherStreamedResponse: + async def _process_streamed_response(http_response: HTTPResponse) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" aiter_bytes = http_response.aiter_bytes() start_response: _GeminiResponse | None = None @@ -262,7 +261,7 @@ async def _process_streamed_response(http_response: HTTPResponse) -> EitherStrea # TODO: Update this once we rework stream responses to be more flexible if _extract_response_parts(start_response).is_left(): - return GeminiStreamStructuredResponse(_content=content, _stream=aiter_bytes) + return GeminiStreamedResponse(_content=content, _stream=aiter_bytes) else: return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes) @@ -339,8 +338,8 @@ def timestamp(self) -> datetime: @dataclass -class GeminiStreamStructuredResponse(StreamStructuredResponse): - """Implementation of `StreamStructuredResponse` for the Gemini model.""" +class GeminiStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for the Gemini model.""" _content: bytearray _stream: AsyncIterator[bytes] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 20be8a33..e3565fca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -30,9 +30,8 @@ from ..tools import ToolDefinition from . 
import ( AgentModel, - EitherStreamedResponse, Model, - StreamStructuredResponse, + StreamedResponse, StreamTextResponse, cached_async_http_client, check_allow_model_requests, @@ -165,7 +164,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: response = await self._completions_create(messages, True, model_settings) async with response: yield await self._process_streamed_response(response) @@ -225,7 +224,7 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: return ModelResponse(items, timestamp=timestamp) @staticmethod - async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse: + async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" timestamp: datetime | None = None start_cost = Cost() @@ -244,7 +243,7 @@ async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) if delta.content is not None: return GroqStreamTextResponse(delta.content, response, timestamp, start_cost) elif delta.tool_calls is not None: - return GroqStreamStructuredResponse( + return GroqStreamedResponse( response, {c.index: c for c in delta.tool_calls}, timestamp, @@ -343,8 +342,8 @@ def timestamp(self) -> datetime: @dataclass -class GroqStreamStructuredResponse(StreamStructuredResponse): - """Implementation of `StreamStructuredResponse` for Groq models.""" +class GroqStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Groq models.""" _response: AsyncStream[ChatCompletionChunk] _delta_tool_calls: dict[int, ChoiceDeltaToolCall] diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 639b9375..f967c123 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -31,9 +31,8 @@ from ..tools import ToolDefinition from . 
import ( AgentModel, - EitherStreamedResponse, Model, - StreamStructuredResponse, + StreamedResponse, StreamTextResponse, cached_async_http_client, ) @@ -164,7 +163,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" response = await self._stream_completions_create(messages, model_settings) async with response: @@ -295,7 +294,7 @@ def _process_response(response: MistralChatCompletionResponse) -> ModelResponse: async def _process_streamed_response( result_tools: list[ToolDefinition], response: MistralEventStreamAsync[MistralCompletionEvent], - ) -> EitherStreamedResponse: + ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" start_cost = Cost() @@ -323,7 +322,7 @@ async def _process_streamed_response( tool_calls = delta.tool_calls if tool_calls or content and result_tools: - return MistralStreamStructuredResponse( + return MistralStreamedResponse( {c.id if c.id else 'null': c for c in tool_calls or []}, {c.name: c for c in result_tools}, response, @@ -510,8 +509,8 @@ def timestamp(self) -> datetime: @dataclass -class MistralStreamStructuredResponse(StreamStructuredResponse): - """Implementation of `StreamStructuredResponse` for Mistral models.""" +class MistralStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Mistral models.""" _function_tools: dict[str, MistralToolCall] _result_tools: dict[str, ToolDefinition] @@ -602,7 +601,7 @@ def _validate_required_json_shema(json_dict: dict[str, Any], json_schema: dict[s if isinstance(json_dict[param], dict) and 'properties' in param_schema: nested_schema = param_schema - if not MistralStreamStructuredResponse._validate_required_json_shema(json_dict[param], nested_schema): + if not MistralStreamedResponse._validate_required_json_shema(json_dict[param], nested_schema): return False return True diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index f852bef3..9f84cd8c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -36,8 +36,7 @@ from . 
import ( AgentModel, Model, - StreamStructuredResponse, - StreamTextResponse, + StreamedResponse, cached_async_http_client, check_allow_model_requests, ) @@ -166,10 +165,10 @@ async def request_stream_text( @asynccontextmanager async def request_stream_structured( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[StreamStructuredResponse]: + ) -> AsyncIterator[StreamedResponse]: response = await self._completions_create(messages, True, model_settings) async with response: - yield OpenAIStreamStructuredResponse(response) + yield OpenAIStreamedResponse(response) @overload async def _completions_create( @@ -320,8 +319,8 @@ def timestamp(self) -> datetime: @dataclass -class OpenAIStreamStructuredResponse(StreamStructuredResponse): - """Implementation of `StreamStructuredResponse` for OpenAI models.""" +class OpenAIStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for OpenAI models.""" _response: AsyncStream[ChatCompletionChunk] diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 41670914..916f902e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -26,9 +26,8 @@ from ..tools import ToolDefinition from . import ( AgentModel, - EitherStreamedResponse, Model, - StreamStructuredResponse, + StreamedResponse, StreamTextResponse, ) @@ -137,7 +136,7 @@ async def request( @asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> AsyncIterator[EitherStreamedResponse]: + ) -> AsyncIterator[StreamedResponse]: msg = self._request(messages, model_settings) cost = Cost() @@ -155,7 +154,7 @@ async def request_stream( if texts: yield TestStreamTextResponse('\n\n'.join(texts), cost) else: - yield TestStreamStructuredResponse(msg, cost) + yield TestStreamedResponse(msg, cost) def gen_tool_args(self, tool_def: ToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate() @@ -242,7 +241,7 @@ def timestamp(self) -> datetime: @dataclass -class TestStreamStructuredResponse(StreamStructuredResponse): +class TestStreamedResponse(StreamedResponse): """A structured response that streams test data.""" _structured_response: ModelResponse diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index d15961d7..8e5c8688 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,14 +1,16 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime -from typing import Generic, TypeVar, cast +from typing import Generic, TypeVar import logfire_api from . 
import _result, _utils, exceptions, messages as _messages, models +from .messages import PartDeltaEvent, PartStartEvent, TextPart, TextPartDelta from .tools import AgentDeps __all__ = ( @@ -118,7 +120,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat cost_so_far: Cost """Cost of the run up until the last request.""" - _stream_response: models.EitherStreamedResponse + _stream_response: models.StreamedResponse _result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] @@ -149,14 +151,8 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu Returns: An async iterable of the response data. """ - # TODO: Drop the following isinstance and use stream_structured even for text (and make that work) - # Rename stream_structured to stream_responses or similar (the idea is just that you aren't getting diffs) - if isinstance(self._stream_response, models.StreamTextResponse): - async for text in self.stream_text(debounce_by=debounce_by): - yield cast(ResultData, text) - else: - async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_result(structured_message, allow_partial=not is_last) + async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): + yield await self.validate_structured_result(structured_message, allow_partial=not is_last) async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. @@ -175,40 +171,41 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. 
""" + + async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: + async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async for group in group_iter: + for maybe_event in group: + if ( + isinstance(maybe_event, PartStartEvent) + and isinstance(maybe_event.part, TextPart) + and maybe_event.part.content + ): + yield maybe_event.part.content, maybe_event.index + elif ( + isinstance(maybe_event, PartDeltaEvent) + and isinstance(maybe_event.delta, TextPartDelta) + and maybe_event.delta.content_delta + ): + yield maybe_event.delta.content_delta, maybe_event.index + with _logfire.span('response stream text') as lf_span: - if isinstance(self._stream_response, models.StreamStructuredResponse): - raise exceptions.UserError('stream_text() can only be used with text responses') if delta: - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for _ in group_iter: - yield ''.join(self._stream_response.get()) - final_delta = ''.join(self._stream_response.get(final=True)) - if final_delta: - yield final_delta + async for text, _ in _stream_text_deltas(): + yield text else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step - chunks: list[str] = [] - combined = '' - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for _ in group_iter: - new = False - for chunk in self._stream_response.get(): - chunks.append(chunk) - new = True - if new: - combined = await self._validate_text_result(''.join(chunks)) - yield combined - - new = False - for chunk in self._stream_response.get(final=True): - chunks.append(chunk) - new = True - if new: - combined = await self._validate_text_result(''.join(chunks)) - yield combined - lf_span.set_attribute('combined_text', combined) - await self._marked_completed(_messages.ModelResponse.from_text(combined)) + chunks: dict[int, str] = defaultdict(str) + combined_validated_text = '' + async for text, index in _stream_text_deltas(): + chunks[index] += text + combined_text = ''.join([chunks[k] for k in sorted(chunks)]) + combined_validated_text = await self._validate_text_result(combined_text) + yield combined_validated_text + + lf_span.set_attribute('combined_text', combined_validated_text) + await self._marked_completed(_messages.ModelResponse.from_text(combined_validated_text)) async def stream_structured( self, *, debounce_by: float | None = 0.1 @@ -228,45 +225,36 @@ async def stream_structured( An async iterable of the structured response message and whether that is the last message. 
""" with _logfire.span('response stream structured') as lf_span: - if isinstance(self._stream_response, models.StreamTextResponse): - raise exceptions.UserError('stream_structured() can only be used with structured responses') - else: - # we should already have a message at this point, yield that first if it has any content - msg = self._stream_response.get() - for item in msg.parts: - if isinstance(item, _messages.ToolCallPart) and item.has_content(): - yield msg, False - break - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for _ in group_iter: - msg = self._stream_response.get() - for item in msg.parts: - if isinstance(item, _messages.ToolCallPart) and item.has_content(): - yield msg, False - break - msg = self._stream_response.get(final=True) - yield msg, True - lf_span.set_attribute('structured_response', msg) - await self._marked_completed(msg) + # we should already have a message at this point, yield that first if it has any content + msg = self._stream_response.get() + for item in msg.parts: + if isinstance(item, _messages.ToolCallPart) and item.has_content(): + yield msg, False + break + async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async for _ in group_iter: + msg = self._stream_response.get() + for item in msg.parts: + if isinstance(item, _messages.ToolCallPart) and item.has_content(): + yield msg, False + break + msg = self._stream_response.get(final=True) + yield msg, True + lf_span.set_attribute('structured_response', msg) + await self._marked_completed(msg) async def get_data(self) -> ResultData: """Stream the whole response, validate and return it.""" async for _ in self._stream_response: pass - if isinstance(self._stream_response, models.StreamTextResponse): - text = ''.join(self._stream_response.get(final=True)) - text = await self._validate_text_result(text) - await self._marked_completed(_messages.ModelResponse.from_text(text)) - return cast(ResultData, text) - else: - message = self._stream_response.get(final=True) - await self._marked_completed(message) - return await self.validate_structured_result(message) + message = self._stream_response.get(final=True) + await self._marked_completed(message) + return await self.validate_structured_result(message) @property def is_structured(self) -> bool: """Return whether the stream response contains structured data (as opposed to text).""" - return isinstance(self._stream_response, models.StreamStructuredResponse) + return isinstance(self._stream_response, models.StreamedResponse) def cost(self) -> Cost: """Return the cost of the whole run. 
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 86532dee..cd297e07 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -52,7 +52,7 @@ from pydantic_ai.models.mistral import ( MistralAgentModel, MistralModel, - MistralStreamStructuredResponse, + MistralStreamedResponse, ) pytestmark = [ @@ -1694,5 +1694,5 @@ def test_generate_user_output_format_multiple(): ], ) def test_validate_required_json_shema(desc: str, schema: dict[str, Any], data: dict[str, Any], expected: bool) -> None: - result = MistralStreamStructuredResponse._validate_required_json_shema(data, schema) # pyright: ignore[reportPrivateUsage] + result = MistralStreamedResponse._validate_required_json_shema(data, schema) # pyright: ignore[reportPrivateUsage] assert result == expected, f'{desc} — expected {expected}, got {result}' From 2a01308bf017e34bfb63350f48978b467db019c5 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:31:22 -0700 Subject: [PATCH 03/34] WIP --- docs/api/models/base.md | 1 - docs/models.md | 1 - pydantic_ai_slim/pydantic_ai/_utils.py | 17 ++-- pydantic_ai_slim/pydantic_ai/agent.py | 4 +- pydantic_ai_slim/pydantic_ai/messages.py | 5 +- .../pydantic_ai/models/__init__.py | 13 +-- .../pydantic_ai/models/function.py | 59 ++++++++++++-- pydantic_ai_slim/pydantic_ai/models/test.py | 78 ++++++++++++------ pydantic_ai_slim/pydantic_ai/result.py | 81 ++++++++++--------- tests/test_streaming.py | 42 +++++----- tests/test_utils.py | 12 +-- 11 files changed, 196 insertions(+), 117 deletions(-) diff --git a/docs/api/models/base.md b/docs/api/models/base.md index 9f847e9a..bf72de7e 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -7,7 +7,6 @@ - Model - AgentModel - AbstractToolDefinition - - StreamTextResponse - StreamedResponse - ALLOW_MODEL_REQUESTS - check_allow_model_requests diff --git a/docs/models.md b/docs/models.md index 95e1935a..f4bcf029 100644 --- a/docs/models.md +++ b/docs/models.md @@ -455,7 +455,6 @@ To implement support for models not already supported, you will need to subclass This in turn will require you to implement the following other abstract base classes: * [`AgentModel`][pydantic_ai.models.AgentModel] -* [`StreamTextResponse`][pydantic_ai.models.StreamTextResponse] * [`StreamedResponse`][pydantic_ai.models.StreamedResponse] The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py). diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 6b87a082..751e92d9 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -137,7 +137,7 @@ def __repr__(self): @asynccontextmanager async def group_by_temporal( aiter: AsyncIterator[T], soft_max_interval: float | None -) -> AsyncIterator[AsyncIterable[list[T]]]: +) -> AsyncIterator[AsyncIterable[tuple[list[T], bool]]]: """Group items from an async iterable into lists based on time interval between them. Effectively debouncing the iterator. @@ -160,13 +160,15 @@ async def group_by_temporal( as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed Returns: - A context manager usable as an iterator async iterable of lists of items from the input async iterable. 
+ A context manager usable as an iterator async iterable of pairs of lists of items from the input async iterable, + and a boolean indicating whether that group is the final one produced by the iterator. """ if soft_max_interval is None: - async def async_iter_groups_noop() -> AsyncIterator[list[T]]: + async def async_iter_groups_noop() -> AsyncIterator[tuple[list[T], bool]]: async for item in aiter: - yield [item] + yield [item], False + yield [], True yield async_iter_groups_noop() return @@ -174,7 +176,7 @@ async def async_iter_groups_noop() -> AsyncIterator[list[T]]: # we might wait for the next item more than once, so we store the task to await next time task: asyncio.Task[T] | None = None - async def async_iter_groups() -> AsyncIterator[list[T]]: + async def async_iter_groups() -> AsyncIterator[tuple[list[T], bool]]: nonlocal task assert soft_max_interval is not None and soft_max_interval >= 0, 'soft_max_interval must be a positive number' @@ -204,8 +206,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]: item = done.pop().result() except StopAsyncIteration: # if the task raised StopAsyncIteration, we're done iterating - if buffer: - yield buffer + yield buffer, True task = None break else: @@ -217,7 +218,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]: group_start_time = time.monotonic() elif buffer: # otherwise if the task timeout expired and we have items in the buffer, yield the buffer - yield buffer + yield buffer, False # clear the buffer and reset the group start time ready for the next group buffer = [] group_start_time = None diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e24d27db..bd463db3 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -427,7 +427,7 @@ async def main(): model_req_span.__exit__(None, None, None) with _logfire.span('handle model response') as handle_span: - maybe_final_result = await self._handle_streamed_model_response(model_response, run_context) + maybe_final_result = await self._handle_streamed_response(model_response, run_context) # Check if we got a final result if isinstance(maybe_final_result, _MarkFinalResult): @@ -997,7 +997,7 @@ async def _process_function_tools( parts.extend(task_results) return parts - async def _handle_streamed_model_response( + async def _handle_streamed_response( self, streamed_response: models.StreamedResponse, run_context: RunContext[AgentDeps], diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index cbe914a3..2fafceae 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -153,6 +153,9 @@ class TextPart: part_kind: Literal['text'] = 'text' """Part type identifier, this is available on all parts as a discriminator.""" + def has_content(self) -> bool: + return bool(self.content) + @dataclass class ArgsJson: @@ -293,7 +296,7 @@ def apply(self, part: ToolCallPart) -> ToolCallPart: class PartStartEvent: """If multiple PartStartEvents are received with the same index, the new one should fully replace the old one.""" - index: int + index: int # TODO: Consider replacing index here and below with part_id part: ModelResponsePart event_kind: Literal['part_start'] = 'part_start' diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 3fd18139..f9c45e2a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py 
@@ -132,6 +132,7 @@ async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None ) -> AsyncIterator[StreamedResponse]: """Make a request to the model and return a streaming response.""" + # This method is not required, but you need to implement it if you want to support streamed responses raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}') # yield is required to make this a generator for type checking # noinspection PyUnreachableCode @@ -141,17 +142,17 @@ async def request_stream( class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: # TODO: Should we drop the None? I think so - """Stream the response as an async iterable, building up the tool call as it goes. + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + """Stream the response as an async iterable of (optional) `ModelResponseStreamEvent`s. - This is an async iterator that yields `None` to avoid doing the work of building the final tool call when - it will often be thrown away. + This is an async iterator that yields events as they are received. It may yield `None` when raw data is received + from the model but there is not enough information to produce a meaningful ModelResponseStreamEvent. """ return self @abstractmethod - async def __anext__(self) -> ModelResponseStreamEvent | None: # TODO: Should we drop the None? I think so - """Process the next chunk of the response, see above for why this returns `None`.""" + async def __anext__(self) -> ModelResponseStreamEvent | None: + """Process the next chunk of the response, see above for why this may return `None`.""" raise NotImplementedError() @abstractmethod diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index f1724737..02b5610b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -18,10 +18,14 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, + ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -167,24 +171,61 @@ async def request_stream( yield FunctionStreamedResponse(response_stream) +_CONTENT_PART_INDEX = -1 + + @dataclass class FunctionStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" _iter: AsyncIterator[str | DeltaToolCalls] - _delta_tool_calls: dict[int, DeltaToolCall] = field(default_factory=dict) _timestamp: datetime = field(default_factory=_utils.now_utc) + _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + async def __anext__(self) -> ModelResponseStreamEvent | None: - raise NotImplementedError # TODO: Need to implement this... 
+ if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() + return await self._event_iterator.__anext__() + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + async for item in self._iter: + if isinstance(item, str): + text = item + content_part = self._parts.get(_CONTENT_PART_INDEX) + if content_part is None: + content_part = TextPart(content=text) + self._parts[_CONTENT_PART_INDEX] = content_part + yield PartStartEvent(index=_CONTENT_PART_INDEX, part=content_part) + else: + assert isinstance(content_part, TextPart), 'Cannot switch to text part mid-stream' + delta = TextPartDelta(content_delta=text) + self._parts[_CONTENT_PART_INDEX] = delta.apply(content_part) + yield PartDeltaEvent(index=_CONTENT_PART_INDEX, delta=delta) + else: + delta_tool_calls = item + for index, delta_tool_call in delta_tool_calls.items(): + existing_part = self._parts.get(index) + if existing_part is None: + assert ( + delta_tool_call.name is not None + ), 'The first delta_tool_call with a given index must include a tool name' + part = ToolCallPart.from_raw_args( + tool_name=delta_tool_call.name, args=delta_tool_call.json_args or '' + ) + self._parts[index] = part + yield PartStartEvent(index=index, part=part) + else: + assert isinstance(existing_part, ToolCallPart), 'Cannot switch to tool call part mid-stream' + if delta_tool_call.json_args is not None: + delta = ToolCallPartDelta(delta_tool_call.json_args) + self._parts[index] = delta.apply(existing_part) + yield PartDeltaEvent(index=index, delta=delta) def get(self, *, final: bool = False) -> ModelResponse: - calls: list[ModelResponsePart] = [] - for c in self._delta_tool_calls.values(): - if c.name is not None and c.json_args is not None: - calls.append(ToolCallPart.from_raw_args(c.name, c.json_args)) - - return ModelResponse(calls, timestamp=self._timestamp) + parts = [self._parts[index] for index in sorted(self._parts)] + return ModelResponse(parts, timestamp=self._timestamp) def usage(self) -> result.Usage: return _estimate_usage([self.get()]) @@ -229,4 +270,6 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: def _estimate_string_usage(content: str) -> int: + if not content: + return 0 return len(re.split(r'[\s",.:]+', content)) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index b1c5c841..1289e2b5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -2,14 +2,13 @@ import re import string -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import date, datetime, timedelta from typing import Any, Literal import pydantic_core -from typing_extensions import assert_never from .. 
import _utils from ..messages import ( @@ -17,8 +16,12 @@ ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, RetryPromptPart, TextPart, + TextPartDelta, ToolCallPart, ToolReturnPart, ) @@ -30,7 +33,7 @@ Model, StreamedResponse, ) -from .function import _estimate_usage # pyright: ignore[reportPrivateUsage] +from .function import _estimate_string_usage, _estimate_usage # pyright: ignore[reportPrivateUsage] @dataclass @@ -142,23 +145,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: msg = self._request(messages, model_settings) usage = _estimate_usage(messages) - - # TODO: Rework this once we make StreamTextResponse more general - texts: list[str] = [] - tool_calls: list[ToolCallPart] = [] - for item in msg.parts: - if isinstance(item, TextPart): - texts.append(item.content) - elif isinstance(item, ToolCallPart): - tool_calls.append(item) - else: - assert_never(item) - - if texts: - # yield TestStreamTextResponse('\n\n'.join(texts), usage) - raise NotImplementedError('TODO: Fix this branch') - else: - yield TestStreamedResponse(msg, usage) + yield TestStreamedResponse(msg, usage) def gen_tool_args(self, tool_def: ToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate() @@ -227,14 +214,48 @@ class TestStreamedResponse(StreamedResponse): _structured_response: ModelResponse _usage: Usage - _iter: Iterator[None] = field(default_factory=lambda: iter([None])) _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - async def __anext__(self) -> None: - return _utils.sync_anext(self._iter) + _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + + async def __anext__(self) -> ModelResponseStreamEvent | None: + if not self._parts: + print('---') + if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() + next_event = await self._event_iterator.__anext__() + existing_part = self._parts.get(next_event.index) + if isinstance(next_event, PartDeltaEvent): + assert existing_part is not None, 'PartDeltaEvent without existing part' + self._parts[next_event.index] = next_event.delta.apply(existing_part) + else: + self._parts[next_event.index] = next_event.part + + self._usage += _estimate_event_usage(next_event) + return next_event + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + for i, part in enumerate(self._structured_response.parts): + if isinstance(part, TextPart): + text = part.content + *words, last_word = text.split(' ') + words = [f'{word} ' for word in words] + words.append(last_word) + if len(words) == 1 and len(text) > 2: + mid = len(text) // 2 + words = [text[:mid], text[mid:]] + yield PartStartEvent(index=i, part=TextPart(content='')) + for word in words: + yield PartDeltaEvent(index=i, delta=TextPartDelta(content_delta=word)) + # yield PartStopEvent(index=i, part=part) + else: + yield PartStartEvent(index=i, part=part) + # yield PartStopEvent(index=i, part=part) def get(self, *, final: bool = False) -> ModelResponse: - return self._structured_response + parts = [self._parts[index] for index in sorted(self._parts)] + return ModelResponse(parts, timestamp=self._timestamp) def usage(self) -> Usage: return self._usage @@ -397,3 +418,12 @@ def _char(self) -> str: rem //= chars s += _chars[self.seed % chars] return s + + +def _estimate_event_usage(event: 
ModelResponseStreamEvent | None) -> Usage: + response_tokens = 0 + if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart): + response_tokens = _estimate_string_usage(event.part.content) + elif isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta): + response_tokens = _estimate_string_usage(event.delta.content_delta) + return Usage(response_tokens=response_tokens, total_tokens=response_tokens) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 666e4681..544e8855 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -6,12 +6,13 @@ from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic, Union +from typing import Generic, Union, cast import logfire_api from typing_extensions import TypeVar from . import _result, _utils, exceptions, messages as _messages, models +from .messages import ModelResponseStreamEvent, TextPart from .tools import AgentDeps, RunContext from .usage import Usage, UsageLimits @@ -227,7 +228,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for group in group_iter: + async for group, _is_final in group_iter: for maybe_event in group: if ( isinstance(maybe_event, _messages.PartStartEvent) @@ -265,10 +266,6 @@ async def stream_structured( ) -> AsyncIterator[tuple[_messages.ModelResponse, bool]]: """Stream the response as an async iterable of Structured LLM Messages. - !!! note - This method will fail if the response is text, - e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `False`. - Args: debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing. Debouncing is particularly important for long structured responses to reduce the overhead of @@ -282,23 +279,21 @@ async def stream_structured( ) with _logfire.span('response stream structured') as lf_span: - # we should already have a message at this point, yield that first if it has any content + # if the message currently has any parts with content, yield before streaming msg = self._stream_response.get() - for item in msg.parts: - if isinstance(item, _messages.ToolCallPart) and item.has_content(): + for part in msg.parts: + if part.has_content(): yield msg, False break + async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for _ in group_iter: - msg = self._stream_response.get() - for item in msg.parts: - if isinstance(item, _messages.ToolCallPart) and item.has_content(): - yield msg, False - break - msg = self._stream_response.get(final=True) - yield msg, True - lf_span.set_attribute('structured_response', msg) - await self._marked_completed(msg) + async for events, is_final in group_iter: + msg = self._stream_response.get(final=is_final) + yield msg, is_final + if is_final: + # TODO: Should this now be `final_response` instead of `structured_response`? 
+ lf_span.set_attribute('structured_response', msg) + await self._marked_completed(msg) async def get_data(self) -> ResultData: """Stream the whole response, validate and return it.""" @@ -312,11 +307,6 @@ async def get_data(self) -> ResultData: await self._marked_completed(message) return await self.validate_structured_result(message) - @property - def is_structured(self) -> bool: - """Return whether the stream response contains structured data (as opposed to text).""" - return isinstance(self._stream_response, models.StreamedResponse) - def usage(self) -> Usage: """Return the usage of the whole run. @@ -333,20 +323,29 @@ async def validate_structured_result( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> ResultData: """Validate a structured result message.""" - assert self._result_schema is not None, 'Expected _result_schema to not be None' - assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None' - match = self._result_schema.find_named_tool(message.parts, self._result_tool_name) - if match is None: - raise exceptions.UnexpectedModelBehavior( - f'Invalid message, unable to find tool: {self._result_schema.tool_names()}' - ) - - call, result_tool = match - result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._result_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data + if self._result_schema is not None and self._result_tool_name is not None: + match = self._result_schema.find_named_tool(message.parts, self._result_tool_name) + if match is None: + raise exceptions.UnexpectedModelBehavior( + f'Invalid response, unable to find tool: {self._result_schema.tool_names()}' + ) + + call, result_tool = match + result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + + for validator in self._result_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data + else: + text = '\n\n'.join(x.content for x in message.parts if isinstance(x, TextPart)) + for validator in self._result_validators: + text = await validator.validate( + text, # pyright: ignore[reportArgumentType] + None, + self._run_ctx, + ) + # Since there is no result tool, we can assume that str is compatible with ResultData + return cast(ResultData, text) async def _validate_text_result(self, text: str) -> str: for validator in self._result_validators: @@ -364,8 +363,10 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: def _get_usage_checking_stream_response( - stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage] -) -> AsyncIterator[ResultData]: + stream_response: AsyncIterator[ModelResponseStreamEvent | None], + limits: UsageLimits | None, + get_usage: Callable[[], Usage], +) -> AsyncIterator[ModelResponseStreamEvent | None]: if limits is not None and limits.has_token_limits(): async def _usage_checking_iterator(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 130d8d6b..6163da80 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -8,7 +8,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai import Agent, UnexpectedModelBehavior, capture_run_messages from pydantic_ai.messages import ( ArgsDict, ArgsJson, @@ -41,7 +41,6 
@@ async def ret_a(x: str) -> str: async with test_agent.run_stream('Hello') as result: assert test_agent.name == 'test_agent' - assert not result.is_structured assert not result.is_complete assert result.all_messages() == snapshot( [ @@ -66,14 +65,6 @@ async def ret_a(x: str) -> str: response = await result.get_data() assert response == snapshot('{"ret_a":"a-apple"}') assert result.is_complete - assert result.usage() == snapshot( - Usage( - requests=2, - request_tokens=103, - response_tokens=11, - total_tokens=114, - ) - ) assert result.timestamp() == IsNow(tz=timezone.utc) assert result.all_messages() == snapshot( [ @@ -88,6 +79,14 @@ async def ret_a(x: str) -> str: ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)), ] ) + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=11, + total_tokens=114, + ) + ) async def test_streamed_structured_response(): @@ -97,7 +96,6 @@ async def test_streamed_structured_response(): async with agent.run_stream('') as result: assert agent.name == 'fig_jam' - assert result.is_structured assert not result.is_complete response = await result.get_data() assert response == snapshot(('a', 'a')) @@ -124,10 +122,11 @@ async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> A assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4]]) - async with agent.run_stream('Hello') as result: - with pytest.raises(UserError, match=r'stream_text\(\) can only be used with text responses'): - async for _ in result.stream_text(): - pass + # TODO: Is the following check still relevant? I'm not sure what it's trying to do... + # async with agent.run_stream('Hello') as result: + # with pytest.raises(UserError, match=r'stream_text\(\) can only be used with text responses'): + # async for _ in result.stream_text(): + # pass async def test_streamed_text_stream(): @@ -136,7 +135,6 @@ async def test_streamed_text_stream(): agent = Agent(m) async with agent.run_stream('Hello') as result: - assert not result.is_structured # typehint to test (via static typing) that the stream type is correctly inferred chunks: list[str] = [c async for c in result.stream()] # one chunk due to group_by_temporal @@ -152,6 +150,9 @@ async def test_streamed_text_stream(): 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.', + # This last value is handled twice due to the debounce_by=None combined with the need to emit + # a final empty chunk to signal the end of the stream + 'The cat sat on the mat.', ] ) @@ -160,10 +161,11 @@ async def test_streamed_text_stream(): ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] ) - async with agent.run_stream('Hello') as result: - with pytest.raises(UserError, match=r'stream_structured\(\) can only be used with structured responses'): - async for _ in result.stream_structured(): - pass + # TODO: Is the following check still relevant? I'm not sure what it's trying to do... 
+ # async with agent.run_stream('Hello') as result: + # with pytest.raises(UserError, match=r'stream_structured\(\) can only be used with structured responses'): + # async for _ in result.stream_structured(): + # pass async def test_plain_response(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 0bcf88ab..96c79203 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,11 +15,11 @@ @pytest.mark.parametrize( 'interval,expected', [ - (None, snapshot([[1], [2], [3]])), - (0, snapshot([[1], [2], [3]])), - (0.02, snapshot([[1], [2], [3]])), - (0.04, snapshot([[1, 2], [3]])), - (0.1, snapshot([[1, 2, 3]])), + (None, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), + (0, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), + (0.02, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), + (0.04, snapshot([([1, 2], False), ([3], True)])), + (0.1, snapshot([([1, 2, 3], True)])), ], ) async def test_group_by_temporal(interval: float | None, expected: list[list[int]]): @@ -32,7 +32,7 @@ async def yield_groups() -> AsyncIterator[int]: await asyncio.sleep(0.02) async with group_by_temporal(yield_groups(), soft_max_interval=interval) as groups_iter: - groups: list[list[int]] = [g async for g in groups_iter] + groups: list[tuple[list[int], bool]] = [g async for g in groups_iter] assert groups == expected From 807c2e59236a37fcbcc3ed931d3bcded5d1f6a45 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sun, 5 Jan 2025 11:04:33 -0700 Subject: [PATCH 04/34] Get streaming tests passing --- pydantic_ai_slim/pydantic_ai/agent.py | 2 + pydantic_ai_slim/pydantic_ai/messages.py | 2 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 68 ++++++++++++++++--- tests/test_streaming.py | 2 +- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index bd463db3..db12fa86 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1029,6 +1029,8 @@ async def _handle_streamed_response( tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] parts: list[_messages.ModelRequestPart] = [] model_response = streamed_response.get() + if not model_response.parts: + raise exceptions.UnexpectedModelBehavior('Received empty model response') for p in model_response.parts: if isinstance(p, ToolCallPart): if tool := self._function_tools.get(p.tool_name): diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 2fafceae..15f73cda 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -293,7 +293,7 @@ def apply(self, part: ToolCallPart) -> ToolCallPart: @dataclass -class PartStartEvent: +class PartStartEvent: # TODO: Consider renaming to PartReplaceEvent, or somehow indicate full replacement is an option """If multiple PartStartEvents are received with the same index, the new one should fully replace the old one.""" index: int # TODO: Consider replacing index here and below with part_id diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 3cd8bd01..c63d4362 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -19,6 +19,8 @@ ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, @@ -256,11 +258,7 @@ async def 
_process_streamed_response(http_response: HTTPResponse) -> StreamedRes if start_response is None: raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') - # TODO: Update this once we rework stream responses to be more flexible - if _extract_response_parts(start_response).is_left(): - return GeminiStreamedResponse(_content=content, _stream=aiter_bytes) - else: - raise NotImplementedError('TODO: delete this branch') + return GeminiStreamedResponse(_content=content, _stream=aiter_bytes) @classmethod def _message_to_gemini_content( @@ -307,9 +305,63 @@ class GeminiStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) - async def __anext__(self) -> None: - chunk = await self._stream.__anext__() - self._content.extend(chunk) + _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + + async def __anext__(self) -> ModelResponseStreamEvent | None: + if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() + next_event = await self._event_iterator.__anext__() + return next_event + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + chunk_index = -1 + current_gemini_response_index = -1 + current_gemini_response_part_index = -1 + last_gemini_part_data: tuple[int, int, _GeminiPartUnion] | None = None + + async for chunk in self._stream: + chunk_index += 1 + next_part_index = -1 + self._content.extend(chunk) + + responses = _gemini_streamed_response_ta.validate_json( + self._content, + experimental_allow_partial='trailing-strings', + ) + + r: _GeminiResponse + for response_index, r in enumerate(responses): + candidate = r['candidates'][0] + if response_index < current_gemini_response_index: + next_part_index += len(candidate['content']['parts']) + continue + + if response_index > current_gemini_response_index: + current_gemini_response_index = response_index + current_gemini_response_part_index = -1 + + for response_part_index, gemini_part in enumerate(candidate['content']['parts']): + next_part_index += 1 + if response_part_index < current_gemini_response_part_index: + continue + current_gemini_response_part_index = response_part_index + + # Don't yield the same part twice + gemini_part_data = (response_index, response_part_index, gemini_part) + if gemini_part_data == last_gemini_part_data: + yield None # TODO: Should be able to safely drop this if we drop yielding None + continue + last_gemini_part_data = gemini_part_data + + if 'text' in gemini_part: + part = TextPart(gemini_part['text']) + yield PartStartEvent(index=next_part_index, part=part) + elif 'function_call' in gemini_part: + part = ToolCallPart.from_raw_args(gemini_part['name'], gemini_part['args']) + yield PartStartEvent(index=next_part_index, part=part) + else: # pragma: no branch + # We don't currently do anything with function_response parts + assert 'function_response' in gemini_part, f'Unhandled part: {gemini_part}' def get(self, *, final: bool = False) -> ModelResponse: """Get the `ModelResponse` at this point. 
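The event protocol being introduced here is consumed with the same bookkeeping that `FunctionStreamedResponse` and `TestStreamedResponse` already use above: keep a dict of parts keyed by the event index, let a `PartStartEvent` replace the entry wholesale, and let a `PartDeltaEvent` update it via `apply`. A minimal sketch of a consumer, using only the event types added in this series (`collect_parts` and its `stream` argument are illustrative names, not part of this diff):

    from collections.abc import AsyncIterator

    from pydantic_ai.messages import (
        ModelResponsePart,
        ModelResponseStreamEvent,
        PartDeltaEvent,
        PartStartEvent,
    )

    async def collect_parts(
        stream: AsyncIterator[ModelResponseStreamEvent | None],
    ) -> list[ModelResponsePart]:
        parts: dict[int, ModelResponsePart] = {}
        async for event in stream:
            if event is None:
                continue  # raw data arrived, but not enough to form an event yet
            if isinstance(event, PartStartEvent):
                # a repeated start event for the same index fully replaces the earlier part
                parts[event.index] = event.part
            elif isinstance(event, PartDeltaEvent):
                # a delta always follows a start event for its index
                parts[event.index] = event.delta.apply(parts[event.index])
        return [parts[index] for index in sorted(parts)]

This is also why the `get()` methods can rebuild a `ModelResponse` at any point mid-stream.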
diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6163da80..e8a052bc 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -273,7 +273,7 @@ async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) - with pytest.raises(UnexpectedModelBehavior, match='Received empty tool call message'): + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): async with agent.run_stream('hello'): pass From 28b5d9d237b75c5f2acdf17d4f156b68e9429ff5 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sun, 5 Jan 2025 12:10:39 -0700 Subject: [PATCH 05/34] Get gemini streaming working --- pydantic_ai_slim/pydantic_ai/messages.py | 1 - pydantic_ai_slim/pydantic_ai/models/gemini.py | 123 ++++++++++-------- pydantic_ai_slim/pydantic_ai/result.py | 20 ++- tests/models/test_gemini.py | 16 ++- tests/test_streaming.py | 2 +- 5 files changed, 98 insertions(+), 64 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 15f73cda..9e68e6a8 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -315,7 +315,6 @@ class PartStopEvent: """A part stop event.""" index: int - part: ModelResponsePart event_kind: Literal['part_stop'] = 'part_stop' diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c63d4362..c0a5f468 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -20,10 +20,12 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, + PartDeltaEvent, PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -305,63 +307,92 @@ class GeminiStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) + _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent | None: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() next_event = await self._event_iterator.__anext__() + + # Update the `_parts` so the `get` method can return the correct value + if isinstance(next_event, PartStartEvent): + self._parts[next_event.index] = next_event.part + elif isinstance(next_event, PartDeltaEvent): + self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) + return next_event async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: - chunk_index = -1 - current_gemini_response_index = -1 - current_gemini_response_part_index = -1 - last_gemini_part_data: tuple[int, int, _GeminiPartUnion] | None = None + current_part_index: int | None = None + current_tool_call_name: str | None = None # None means we are in a text part or have no parts at all + + async for gemini_response in self._get_gemini_responses(): + candidate = gemini_response['candidates'][0] + gemini_part: _GeminiPartUnion + for gemini_part in candidate['content']['parts']: + if 'text' in gemini_part: + # The following condition holds if and only if we are not already in a text 
part:
+ if current_part_index is None or current_tool_call_name is not None:
+ current_tool_call_name = None
+ if current_part_index is None:
+ current_part_index = 0
+ else:
+ # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events..
+ current_part_index += 1
+ part = TextPart(gemini_part['text'])
+ yield PartStartEvent(index=current_part_index, part=part)
+ else:
+ delta = TextPartDelta(gemini_part['text'])
+ yield PartDeltaEvent(index=current_part_index, delta=TextPartDelta(gemini_part['text']))
+ elif 'function_call' in gemini_part:
+ # Here, we assume all function_call parts are complete and don't have deltas.
+ # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly;
+ # it would just be a bit more complicated, and we'd need to confirm the intended semantics.
+ current_tool_call_name = gemini_part['function_call']['name']
+ if current_part_index is None:
+ current_part_index = 0
+ else:
+ # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events..
+ current_part_index += 1
+ yield PartStartEvent(
+ index=current_part_index,
+ part=ToolCallPart.from_raw_args(current_tool_call_name, gemini_part['function_call']['args']),
+ )
+ else:
+ assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'
+
+ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
+ # This method exists to ensure we only yield completed items, so we don't need to worry about
+ # partial gemini responses, which would make everything more complicated.
+
+ gemini_responses: list[_GeminiResponse] = []
+ current_gemini_response_index = 0
+ # Right now, there are some circumstances where we will have information that could be yielded sooner than it is,
+ # but changing that would make things a lot more complicated.
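+ # Note: each chunk only extends the raw byte buffer; the full buffer is then re-validated with
+ # experimental_allow_partial, so a response split across chunks becomes available as soon as its
+ # JSON is syntactically complete, and only the trailing (possibly partial) response is held back
+ # until the byte stream is exhausted.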
async for chunk in self._stream: - chunk_index += 1 - next_part_index = -1 self._content.extend(chunk) - responses = _gemini_streamed_response_ta.validate_json( + gemini_responses = _gemini_streamed_response_ta.validate_json( self._content, experimental_allow_partial='trailing-strings', ) - r: _GeminiResponse - for response_index, r in enumerate(responses): - candidate = r['candidates'][0] - if response_index < current_gemini_response_index: - next_part_index += len(candidate['content']['parts']) - continue - - if response_index > current_gemini_response_index: - current_gemini_response_index = response_index - current_gemini_response_part_index = -1 - - for response_part_index, gemini_part in enumerate(candidate['content']['parts']): - next_part_index += 1 - if response_part_index < current_gemini_response_part_index: - continue - current_gemini_response_part_index = response_part_index - - # Don't yield the same part twice - gemini_part_data = (response_index, response_part_index, gemini_part) - if gemini_part_data == last_gemini_part_data: - yield None # TODO: Should be able to safely drop this if we drop yielding None - continue - last_gemini_part_data = gemini_part_data - - if 'text' in gemini_part: - part = TextPart(gemini_part['text']) - yield PartStartEvent(index=next_part_index, part=part) - elif 'function_call' in gemini_part: - part = ToolCallPart.from_raw_args(gemini_part['name'], gemini_part['args']) - yield PartStartEvent(index=next_part_index, part=part) - else: # pragma: no branch - # We don't currently do anything with function_response parts - assert 'function_response' in gemini_part, f'Unhandled part: {gemini_part}' + # The idea: yield only up to the latest response, which might still be partial. + # Note that if the latest response is complete, we could yield it immediately, but there's not a good + # allow_partial API to determine if the last item in the list is complete. + responses_to_yield = gemini_responses[:-1] + for r in responses_to_yield[current_gemini_response_index:]: + current_gemini_response_index += 1 + self._usage += _metadata_as_usage(r) + yield r + + # Now yield the final response, which should be complete + if gemini_responses: + r = gemini_responses[-1] + print(r) + yield r def get(self, *, final: bool = False) -> ModelResponse: """Get the `ModelResponse` at this point. @@ -372,17 +403,7 @@ def get(self, *, final: bool = False) -> ModelResponse: I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from separate parts. 
""" - responses = _gemini_streamed_response_ta.validate_json( - self._content, - experimental_allow_partial='off' if final else 'trailing-strings', - ) - combined_parts: list[_GeminiPartUnion] = [] - self._usage = result.Usage() - for r in responses: - self._usage += _metadata_as_usage(r) - candidate = r['candidates'][0] - combined_parts.extend(candidate['content']['parts']) - return _process_response_from_parts(combined_parts, timestamp=self._timestamp) + return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) def usage(self) -> result.Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 544e8855..986e94a2 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -208,10 +208,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. - !!! note - This method will fail if the response is structured, - e.g. if [`is_structured`][pydantic_ai.result.StreamedRunResult.is_structured] returns `True`. - !!! note Result validators will NOT be called on the text result if `delta=True`. @@ -227,9 +223,18 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = ) async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: + # if the response currently has any parts with content, yield those before streaming + # TODO: This needs to be rolled into the group_by_temporal below + msg = self._stream_response.get() + for i, part in enumerate(msg.parts): + # TODO: Probably need to replace this usage of index with a (tracked) part ID or similar + # (It's not guaranteed that this index `i` matches what comes out of the maybe_event.index below) + if isinstance(part, TextPart): + yield part.content, i + async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for group, _is_final in group_iter: - for maybe_event in group: + async for events, _is_final in group_iter: + for maybe_event in events: if ( isinstance(maybe_event, _messages.PartStartEvent) and isinstance(maybe_event.part, _messages.TextPart) @@ -254,7 +259,7 @@ async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: combined_validated_text = '' async for text, index in _stream_text_deltas(): chunks[index] += text - combined_text = ''.join([chunks[k] for k in sorted(chunks)]) + combined_text = '\n\n'.join([chunks[k] for k in sorted(chunks)]) combined_validated_text = await self._validate_text_result(combined_text) yield combined_validated_text @@ -280,6 +285,7 @@ async def stream_structured( with _logfire.span('response stream structured') as lf_span: # if the message currently has any parts with content, yield before streaming + # TODO: This needs to be rolled into the group_by_temporal below... 
msg = self._stream_response.get() for part in msg.parts: if part.has_content(): diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 63de8cab..ff963e51 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -562,13 +562,21 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] - assert chunks == snapshot(['Hello ', 'Hello world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) + assert chunks == snapshot( + [ + 'Hello ', + 'Hello world', + # This last value is repeated due to the debounce_by=None combined with the need to emit + # a final empty chunk to signal the end of the stream + 'Hello world', + ] + ) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] - assert chunks == snapshot(['', 'Hello ', 'world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) + assert chunks == snapshot(['Hello ', 'world']) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index e8a052bc..7f74bafe 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -150,7 +150,7 @@ async def test_streamed_text_stream(): 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.', - # This last value is handled twice due to the debounce_by=None combined with the need to emit + # This last value is repeated due to the debounce_by=None combined with the need to emit # a final empty chunk to signal the end of the stream 'The cat sat on the mat.', ] From 9c1a9b7a224dc1982f25217f48a5940ce0d34ff4 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sun, 5 Jan 2025 12:19:36 -0700 Subject: [PATCH 06/34] Further improve gemini streaming --- pydantic_ai_slim/pydantic_ai/messages.py | 11 ++++-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 38 +++---------------- pydantic_ai_slim/pydantic_ai/models/test.py | 14 ++++--- pydantic_ai_slim/pydantic_ai/result.py | 2 +- tests/models/test_groq.py | 5 --- tests/models/test_mistral.py | 9 ----- tests/models/test_openai.py | 5 --- tests/test_usage_limits.py | 1 - 8 files changed, 23 insertions(+), 62 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 9e68e6a8..a7e40e4d 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -272,7 +272,9 @@ class TextPartDelta: content_delta: str part_delta_kind: Literal['text'] = 'text' - def apply(self, part: TextPart) -> TextPart: + def apply(self, part: ModelResponsePart) -> TextPart: + if not isinstance(part, TextPart): + raise ValueError('Cannot apply TextPartDeltas to non-TextParts') return replace(part, content=part.content + self.content_delta) @@ -283,8 +285,11 @@ class ToolCallPartDelta: args_json_delta: str part_delta_kind: Literal['tool_call'] = 'tool_call' - def apply(self, part: ToolCallPart) -> ToolCallPart: - assert isinstance(part.args, ArgsJson), 'Cannot apply deltas to non-JSON tool arguments' + def 
apply(self, part: ModelResponsePart) -> ToolCallPart: + if not isinstance(part, ToolCallPart): + raise ValueError('Cannot apply ToolCallPartDeltas to non-ToolCallParts') + if not isinstance(part.args, ArgsJson): + raise ValueError('Cannot apply deltas to non-JSON tool arguments') updated_json = part.args.args_json + self.args_json_delta return replace(part, args=ArgsJson(updated_json)) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c0a5f468..e79760b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -11,7 +11,7 @@ import pydantic from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse -from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never +from typing_extensions import NotRequired, TypedDict, assert_never from .. import UnexpectedModelBehavior, _utils, exceptions, result from ..messages import ( @@ -313,12 +313,16 @@ class GeminiStreamedResponse(StreamedResponse): async def __anext__(self) -> ModelResponseStreamEvent | None: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() + next_event = await self._event_iterator.__anext__() + if next_event is None: + return None - # Update the `_parts` so the `get` method can return the correct value if isinstance(next_event, PartStartEvent): self._parts[next_event.index] = next_event.part elif isinstance(next_event, PartDeltaEvent): + existing_part = self._parts.get(next_event.index) + assert existing_part is not None, 'PartDeltaEvent without existing part' self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) return next_event @@ -344,7 +348,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | part = TextPart(gemini_part['text']) yield PartStartEvent(index=current_part_index, part=part) else: - delta = TextPartDelta(gemini_part['text']) yield PartDeltaEvent(index=current_part_index, delta=TextPartDelta(gemini_part['text'])) elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. @@ -600,35 +603,6 @@ class _GeminiResponse(TypedDict): prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]] -# TODO: Delete the next three functions as part of this PR -def _extract_response_parts( - response: _GeminiResponse, -) -> _utils.Either[list[_GeminiFunctionCallPart], list[_GeminiTextPart]]: - """Extract the parts of the response from the Gemini API. - - Returns Either a list of function calls (Either.left) or a list of text parts (Either.right). 
- """ - if len(response['candidates']) != 1: - raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') - parts = response['candidates'][0]['content']['parts'] - if _all_function_call_parts(parts): - return _utils.Either(left=parts) - elif _all_text_parts(parts): - return _utils.Either(right=parts) - else: - raise exceptions.UnexpectedModelBehavior( - f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {parts!r}' - ) - - -def _all_function_call_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiFunctionCallPart]]: - return all('function_call' in part for part in parts) - - -def _all_text_parts(parts: list[_GeminiPartUnion]) -> TypeGuard[list[_GeminiTextPart]]: - return all('text' in part for part in parts) - - class _GeminiCandidates(TypedDict): """See .""" diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 1289e2b5..c8386d23 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -220,17 +220,19 @@ class TestStreamedResponse(StreamedResponse): _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent | None: - if not self._parts: - print('---') if self._event_iterator is None: self._event_iterator = self._get_event_iterator() + next_event = await self._event_iterator.__anext__() - existing_part = self._parts.get(next_event.index) - if isinstance(next_event, PartDeltaEvent): + if next_event is None: + return None + + if isinstance(next_event, PartStartEvent): + self._parts[next_event.index] = next_event.part + elif isinstance(next_event, PartDeltaEvent): + existing_part = self._parts.get(next_event.index) assert existing_part is not None, 'PartDeltaEvent without existing part' self._parts[next_event.index] = next_event.delta.apply(existing_part) - else: - self._parts[next_event.index] = next_event.part self._usage += _estimate_event_usage(next_event) return next_event diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 986e94a2..466a08e5 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -293,7 +293,7 @@ async def stream_structured( break async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for events, is_final in group_iter: + async for _events, is_final in group_iter: msg = self._stream_response.get(final=is_final) yield msg, is_final if is_final: diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 93a7cf63..0cda18a9 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -340,7 +340,6 @@ async def test_stream_text(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete @@ -353,7 +352,6 @@ async def test_stream_text_finish_reason(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.']) assert result.is_complete @@ -399,7 +397,6 @@ async def test_stream_structured(allow_model_requests: None): agent = Agent(m, 
result_type=MyTypedDict) async with agent.run_stream('') as result: - assert result.is_structured assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( [ @@ -447,7 +444,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): agent = Agent(m, result_type=MyTypedDict) async with agent.run_stream('') as result: - assert result.is_structured assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( [ @@ -478,7 +474,6 @@ async def test_no_delta(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 19b807f5..6274310a 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -308,7 +308,6 @@ async def test_stream_text(allow_model_requests: None): # When async with agent.run_stream('') as result: # Then - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot( ['hello ', 'hello world ', 'hello world welcome ', 'hello world welcome mistral'] @@ -329,7 +328,6 @@ async def test_stream_text_finish_reason(allow_model_requests: None): # When async with agent.run_stream('') as result: # Then - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.']) assert result.is_complete @@ -345,7 +343,6 @@ async def test_no_delta(allow_model_requests: None): # When async with agent.run_stream('') as result: # Then - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete @@ -588,7 +585,6 @@ class MyTypedDict(TypedDict, total=False): # When async with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [dict(c) async for c in result.stream(debounce_by=None)] assert v == snapshot( @@ -702,7 +698,6 @@ class MyTypedDict(TypedDict, total=False): # When async with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot( @@ -761,7 +756,6 @@ async def test_stream_result_type_primitif_int(allow_model_requests: None): # When async with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot([1, 1, 1]) @@ -823,7 +817,6 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None): # When async with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot( @@ -920,7 +913,6 @@ class MyTypedBaseModel(BaseModel): # When async with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot( @@ -1009,7 +1001,6 @@ class MyTypedBaseModel(BaseModel): # When async 
with agent.run_stream('User prompt value') as result: # Then - assert result.is_structured assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 722f51f7..71e83f21 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -360,7 +360,6 @@ async def test_stream_text(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete @@ -374,7 +373,6 @@ async def test_stream_text_finish_reason(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.']) assert result.is_complete @@ -420,7 +418,6 @@ async def test_stream_structured(allow_model_requests: None): agent = Agent(m, result_type=MyTypedDict) async with agent.run_stream('') as result: - assert result.is_structured assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( [ @@ -449,7 +446,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): agent = Agent(m, result_type=MyTypedDict) async with agent.run_stream('') as result: - assert result.is_structured assert not result.is_complete assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( [ @@ -484,7 +480,6 @@ async def test_no_delta(allow_model_requests: None): agent = Agent(m) async with agent.run_stream('') as result: - assert not result.is_structured assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index cbbfec34..877a8977 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -79,7 +79,6 @@ async def ret_a(x: str) -> str: async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: assert test_agent.name == 'test_agent' - assert not result.is_structured assert not result.is_complete assert result.all_messages() == snapshot( [ From 4845dfbd5c7ddd83a17dc08e9e81f67b24726fa1 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sun, 5 Jan 2025 12:27:22 -0700 Subject: [PATCH 07/34] Fix more tests --- pydantic_ai_slim/pydantic_ai/models/function.py | 14 +++++++++++++- pydantic_ai_slim/pydantic_ai/result.py | 2 +- tests/models/test_model_function.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 02b5610b..25702717 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -168,7 +168,19 @@ async def request_stream( self.stream_function is not None ), 'FunctionModel must receive a `stream_function` to support streamed requests' response_stream = self.stream_function(messages, self.agent_info) - yield FunctionStreamedResponse(response_stream) + + # Explicitly check that we get at least one value, so we can produce a nicer error message for misuse + try: + first = await 
response_stream.__anext__() + except StopAsyncIteration as e: + raise ValueError('Stream function must return at least one item') from e + + async def peeked_stream(): + yield first + async for item in response_stream: + yield item + + yield FunctionStreamedResponse(peeked_stream()) _CONTENT_PART_INDEX = -1 diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 466a08e5..eb3a22ca 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -229,7 +229,7 @@ async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: for i, part in enumerate(msg.parts): # TODO: Probably need to replace this usage of index with a (tracked) part ID or similar # (It's not guaranteed that this index `i` matches what comes out of the maybe_event.index below) - if isinstance(part, TextPart): + if isinstance(part, TextPart) and part.content: yield part.content, i async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index dec5c2fa..e8da20e9 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -399,7 +399,7 @@ async def test_stream_text(): ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.usage() == snapshot(Usage(requests=1)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=50, response_tokens=2, total_tokens=52)) class Foo(BaseModel): From e0ea9c95b601a0ef8b4b26824f9a82d21da81ce0 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Sun, 5 Jan 2025 12:35:45 -0700 Subject: [PATCH 08/34] Fix more tests --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- tests/models/test_gemini.py | 13 +++++++------ tests/test_streaming.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index e79760b8..636fc528 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -394,7 +394,7 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: # Now yield the final response, which should be complete if gemini_responses: r = gemini_responses[-1] - print(r) + self._usage += _metadata_as_usage(r) yield r def get(self, *, final: bool = False) -> ModelResponse: diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ff963e51..218ef481 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -571,12 +571,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): 'Hello world', ] ) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['Hello ', 'world']) - assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -607,7 +607,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): 
async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] - assert chunks == snapshot([(1, 2), (1, 2), (1, 2)]) + assert chunks == snapshot([(1, 2), (1, 2)]) assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) @@ -689,6 +689,7 @@ async def bar(y: str) -> str: assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"]) +# TODO: Is this test still necessary now that heterogeneous streaming is allowed? async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): responses = [ gemini_response(_content_model_response(ModelResponse.from_text('Hello '))), @@ -713,10 +714,10 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m) - msg = 'Streamed response with unexpected content, expected all parts to be text' async with agent.run_stream('Hello') as result: - with pytest.raises(UnexpectedModelBehavior, match=msg): - await result.get_data() + # msg = 'Streamed response with unexpected content, expected all parts to be text' + # with pytest.raises(UnexpectedModelBehavior, match=msg): + await result.get_data() async def test_empty_text_ignored(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7f74bafe..5be697e6 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -122,7 +122,7 @@ async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> A assert chunks == snapshot([[1], [1, 2, 3, 4], [1, 2, 3, 4]]) - # TODO: Is the following check still relevant? I'm not sure what it's trying to do... + # TODO: Can we remove the following check now? # async with agent.run_stream('Hello') as result: # with pytest.raises(UserError, match=r'stream_text\(\) can only be used with text responses'): # async for _ in result.stream_text(): @@ -161,7 +161,7 @@ async def test_streamed_text_stream(): ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] ) - # TODO: Is the following check still relevant? I'm not sure what it's trying to do... + # TODO: Can we remove the following check now? # async with agent.run_stream('Hello') as result: # with pytest.raises(UserError, match=r'stream_structured\(\) can only be used with structured responses'): # async for _ in result.stream_structured(): From d8bed7a32c58cf2520905a7511ea57705a126843 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:15:22 -0700 Subject: [PATCH 09/34] Fix groq tests and examples --- docs/message-history.md | 2 +- docs/results.md | 4 +- .../pydantic_ai/models/function.py | 43 +++-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 11 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 181 ++++++++++++------ pydantic_ai_slim/pydantic_ai/result.py | 2 +- tests/models/test_groq.py | 13 +- tests/test_streaming.py | 10 +- 8 files changed, 169 insertions(+), 97 deletions(-) diff --git a/docs/message-history.md b/docs/message-history.md index a8e94209..6d058574 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -98,7 +98,7 @@ async def main(): ] """ - async for text in result.stream(): + async for text in result.stream_text(): print(text) #> Did you hear #> Did you hear about the toothpaste diff --git a/docs/results.md b/docs/results.md index 292ed11f..802619a8 100644 --- a/docs/results.md +++ b/docs/results.md @@ -176,7 +176,7 @@ agent = Agent('gemini-1.5-flash') # (1)! 
async def main():
     async with agent.run_stream('Where does "hello world" come from?') as result:  # (2)!
-        async for message in result.stream():  # (3)!
+        async for message in result.stream_text():  # (3)!
             print(message)
             #> The first known
             #> The first known use of "hello,
@@ -188,7 +188,7 @@ async def main():
 
 1. Streaming works with the standard [`Agent`][pydantic_ai.Agent] class, and doesn't require any special setup, just a model that supports streaming (currently all models support streaming).
 2. The [`Agent.run_stream()`][pydantic_ai.Agent.run_stream] method is used to start a streamed run, this method returns a context manager so the connection can be closed when the stream completes.
-3. Each item yield by [`StreamedRunResult.stream()`][pydantic_ai.result.StreamedRunResult.stream] is the complete text response, extended as new data is received.
+3. Each item yielded by [`StreamedRunResult.stream_text()`][pydantic_ai.result.StreamedRunResult.stream_text] is the complete text response, extended as new data is received.
 
 _(This example is complete, it can be run "as is")_
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index 25702717..2ccc3b21 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -183,9 +183,6 @@ async def peeked_stream():
     yield FunctionStreamedResponse(peeked_stream())
 
 
-_CONTENT_PART_INDEX = -1
-
-
 @dataclass
 class FunctionStreamedResponse(StreamedResponse):
     """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
@@ -193,6 +190,9 @@ class FunctionStreamedResponse(StreamedResponse):
     _iter: AsyncIterator[str | DeltaToolCalls]
     _timestamp: datetime = field(default_factory=_utils.now_utc)
 
+    _next_part_index: int = field(default=0, init=False)
+    _content_part_index: int | None = field(default=None, init=False)
+    _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False)
     _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False)
 
@@ -203,37 +203,48 @@ async def __anext__(self) -> ModelResponseStreamEvent | None:
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]:
        async for item in self._iter:
+            # TODO: Create a PartsStreamManager class that wraps
+            # _next_part_index, _content_part_index, _tool_call_index_to_part_index, and _parts
+            # and is reusable across the different StreamedResponse implementations.
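+            # Each item is either a text delta (a plain str) or a DeltaToolCalls mapping of
+            # tool call index to an incremental update for that tool call.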
if isinstance(item, str): text = item - content_part = self._parts.get(_CONTENT_PART_INDEX) - if content_part is None: + if self._content_part_index is None: content_part = TextPart(content=text) - self._parts[_CONTENT_PART_INDEX] = content_part - yield PartStartEvent(index=_CONTENT_PART_INDEX, part=content_part) + self._content_part_index = self._next_part_index + self._parts[self._content_part_index] = content_part + self._next_part_index += 1 + yield PartStartEvent(index=self._content_part_index, part=content_part) else: - assert isinstance(content_part, TextPart), 'Cannot switch to text part mid-stream' + content_part = self._parts[self._content_part_index] + assert isinstance(content_part, TextPart), 'The content part must be a text part' delta = TextPartDelta(content_delta=text) - self._parts[_CONTENT_PART_INDEX] = delta.apply(content_part) - yield PartDeltaEvent(index=_CONTENT_PART_INDEX, delta=delta) + self._parts[self._content_part_index] = delta.apply(content_part) + yield PartDeltaEvent(index=self._content_part_index, delta=delta) else: delta_tool_calls = item - for index, delta_tool_call in delta_tool_calls.items(): - existing_part = self._parts.get(index) - if existing_part is None: + for dtc_index, delta_tool_call in delta_tool_calls.items(): + existing_part_index = self._tool_call_index_to_part_index.get(dtc_index) + if existing_part_index is None: + new_part_index = self._next_part_index + self._tool_call_index_to_part_index[dtc_index] = new_part_index + self._next_part_index += 1 assert ( delta_tool_call.name is not None ), 'The first delta_tool_call with a given index must include a tool name' part = ToolCallPart.from_raw_args( tool_name=delta_tool_call.name, args=delta_tool_call.json_args or '' ) - self._parts[index] = part - yield PartStartEvent(index=index, part=part) + self._parts[new_part_index] = part + yield PartStartEvent(index=new_part_index, part=part) else: + existing_part = self._parts[existing_part_index] assert isinstance(existing_part, ToolCallPart), 'Cannot switch to tool call part mid-stream' if delta_tool_call.json_args is not None: delta = ToolCallPartDelta(delta_tool_call.json_args) - self._parts[index] = delta.apply(existing_part) - yield PartDeltaEvent(index=index, delta=delta) + self._parts[existing_part_index] = delta.apply(existing_part) + yield PartDeltaEvent(index=existing_part_index, delta=delta) def get(self, *, final: bool = False) -> ModelResponse: parts = [self._parts[index] for index in sorted(self._parts)] diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 636fc528..cf0c29a0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -315,8 +315,6 @@ async def __anext__(self) -> ModelResponseStreamEvent | None: self._event_iterator = self._get_event_iterator() next_event = await self._event_iterator.__anext__() - if next_event is None: - return None if isinstance(next_event, PartStartEvent): self._parts[next_event.index] = next_event.part @@ -398,14 +396,7 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: yield r def get(self, *, final: bool = False) -> ModelResponse: - """Get the `ModelResponse` at this point. - - NOTE: It's not clear how the stream of responses should be combined because Gemini seems to always - reply with a single response, when returning a structured data. 
- - I'm therefore assuming that each part contains a complete tool call, and not trying to combine data from - separate parts. - """ + """Get the `ModelResponse` at this point.""" return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) def usage(self) -> result.Usage: diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 3f04ae9a..fad0a186 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -10,17 +10,22 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import UnexpectedModelBehavior, _utils, result +from .. import _utils, result from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, + ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -165,7 +170,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: response = await self._completions_create(messages, True, model_settings) async with response: - yield await self._process_streamed_response(response) + yield GroqStreamedResponse(response) @overload async def _completions_create( @@ -221,34 +226,6 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id)) return ModelResponse(items, timestamp=timestamp) - @staticmethod - async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> StreamedResponse: - """Process a streamed response, and prepare a streaming response to return.""" - timestamp: datetime | None = None - start_usage = Usage() - # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content` - while True: - try: - chunk = await response.__anext__() - except StopAsyncIteration as e: - raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e - timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) - start_usage += _map_usage(chunk) - - if chunk.choices: - delta = chunk.choices[0].delta - - if delta.content is not None: - # return GroqStreamTextResponse(delta.content, response, timestamp, start_usage) - raise NotImplementedError('Fix this branch') - elif delta.tool_calls is not None: - return GroqStreamedResponse( - response, - {c.index: c for c in delta.tool_calls}, - timestamp, - start_usage, - ) - @classmethod def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" @@ -304,48 +281,132 @@ class GroqStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Groq models.""" _response: AsyncStream[ChatCompletionChunk] - _delta_tool_calls: dict[int, ChoiceDeltaToolCall] - _timestamp: datetime - _usage: result.Usage - async def __anext__(self) -> None: - chunk = await self._response.__anext__() - self._usage = _map_usage(chunk) + _timestamp: datetime | None = field(default=None, init=False) + _usage: result.Usage = field(default_factory=result.Usage, init=False) + _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) + _content_part_index: int | None = field(default=None, 
init=False) + _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) + _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) - try: - choice = chunk.choices[0] - except IndexError: - raise StopAsyncIteration() + async def __anext__(self) -> ModelResponseStreamEvent | None: + if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() - if choice.finish_reason is not None: - raise StopAsyncIteration() + next_event = await self._event_iterator.__anext__() - assert choice.delta.content is None, f'Expected tool calls, got content instead, invalid chunk: {chunk!r}' + if isinstance(next_event, PartStartEvent): + self._parts[next_event.index] = next_event.part + elif isinstance(next_event, PartDeltaEvent): + existing_part = self._parts.get(next_event.index) + assert existing_part is not None, 'PartDeltaEvent without existing part' + self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) - for new in choice.delta.tool_calls or []: - if current := self._delta_tool_calls.get(new.index): - if current.function is None: - current.function = new.function - elif new.function is not None: - current.function.name = _utils.add_optional(current.function.name, new.function.name) - current.function.arguments = _utils.add_optional(current.function.arguments, new.function.arguments) - else: - self._delta_tool_calls[new.index] = new + return next_event - def get(self, *, final: bool = False) -> ModelResponse: - items: list[ModelResponsePart] = [] - for c in self._delta_tool_calls.values(): - if f := c.function: - if f.name is not None and f.arguments is not None: - items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id)) + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: # noqa C901 + # TODO: Simplify this through the use of a StreamedPartsManager or whatever + current_part_index: int | None = None + + async for chunk in self._response: + self._timestamp = self._timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) + self._usage += _map_usage(chunk) - return ModelResponse(items, timestamp=self._timestamp) + if not chunk.choices: + continue + choice = chunk.choices[0] + + # Handle the text part of the response + content = choice.delta.content + if content is not None: + if self._content_part_index is not None: + yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(content)) + else: + if current_part_index is None: + current_part_index = 0 + else: + # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. 
+ current_part_index += 1 + + self._content_part_index = current_part_index + part = TextPart(content) + yield PartStartEvent(index=current_part_index, part=part) + + # Handle the tool calls + for dtc in choice.delta.tool_calls or []: + if not dtc.function: + continue + + if existing := self._delta_tool_calls.get(dtc.index): + # We've already received a delta_tool_call for this index + existing.id = existing.id or dtc.id + if existing.function is None: + existing.function = dtc.function + else: + existing.function.name = _utils.add_optional(existing.function.name, dtc.function.name) + existing.function.arguments = _utils.add_optional( + existing.function.arguments, dtc.function.arguments + ) + + if not (existing.function.name and existing.function.arguments): + continue # We don't have enough information to create a part + + part_index = self._tool_call_index_to_part_index.get(dtc.index) + if part_index is None: + if current_part_index is None: + current_part_index = 0 + else: + # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. + current_part_index += 1 + self._tool_call_index_to_part_index[dtc.index] = current_part_index + yield PartStartEvent( + index=current_part_index, + part=ToolCallPart.from_raw_args( + existing.function.name, existing.function.arguments, tool_call_id=existing.id + ), + ) + elif dtc.function.name: + # We don't currently nicely support streaming updates to the function call name + # So we just replace the whole part if the name has changed + yield PartStartEvent( + index=part_index, + part=ToolCallPart.from_raw_args( + existing.function.name, existing.function.arguments, tool_call_id=existing.id + ), + ) + elif dtc.function.arguments: + yield PartDeltaEvent( + index=part_index, + delta=ToolCallPartDelta(dtc.function.arguments), + ) + + else: + self._delta_tool_calls[dtc.index] = dtc + + if dtc.function.name and dtc.function.arguments: + # This is the first delta_tool_call we've received with this index + if current_part_index is None: + current_part_index = 0 + else: + # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. 
+ current_part_index += 1 + self._tool_call_index_to_part_index[dtc.index] = current_part_index + yield PartStartEvent( + index=current_part_index, + part=ToolCallPart.from_raw_args( + dtc.function.name, dtc.function.arguments, tool_call_id=dtc.id + ), + ) + + def get(self, *, final: bool = False) -> ModelResponse: + return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self.timestamp()) def usage(self) -> Usage: return self._usage def timestamp(self) -> datetime: - return self._timestamp + return self._timestamp or datetime.now(tz=timezone.utc) def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index eb3a22ca..3176cc8f 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -259,7 +259,7 @@ async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: combined_validated_text = '' async for text, index in _stream_text_deltas(): chunks[index] += text - combined_text = '\n\n'.join([chunks[k] for k in sorted(chunks)]) + combined_text = ''.join([chunks[k] for k in sorted(chunks)]) combined_validated_text = await self._validate_text_result(combined_text) yield combined_validated_text diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 0cda18a9..4fd250b4 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -61,6 +61,9 @@ class MockAsyncStream: async def __anext__(self) -> chat.ChatCompletionChunk: return _utils.sync_anext(self._iter) + def __aiter__(self) -> MockAsyncStream: + return self + async def __aenter__(self): return self @@ -341,7 +344,7 @@ async def test_stream_text(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world']) assert result.is_complete @@ -353,7 +356,9 @@ async def test_stream_text_finish_reason(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.']) + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.', 'hello world.'] + ) assert result.is_complete @@ -462,7 +467,7 @@ async def test_no_content(allow_model_requests: None): m = GroqModel('llama-3.1-70b-versatile', groq_client=mock_client) agent = Agent(m, result_type=MyTypedDict) - with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'): + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): async with agent.run_stream(''): pass # pragma: no cover @@ -475,5 +480,5 @@ async def test_no_delta(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world']) assert result.is_complete diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 5be697e6..1ad34f3b 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -150,12 +150,18 @@ async def test_streamed_text_stream(): 'The cat sat on 
', 'The cat sat on the ', 'The cat sat on the mat.', - # This last value is repeated due to the debounce_by=None combined with the need to emit - # a final empty chunk to signal the end of the stream + # This last value is repeated due to the debounce_by=None combined with the need to emit a final empty + # chunk to signal the end of the stream (which is used to determine whether to allow partial JSON) 'The cat sat on the mat.', ] ) + async with agent.run_stream('Hello') as result: + # with stream_text, there is no need to do partial validation, so we only get the final message once: + assert [c async for c in result.stream_text(delta=False, debounce_by=None)] == snapshot( + ['The ', 'The cat ', 'The cat sat ', 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.'] + ) + async with agent.run_stream('Hello') as result: assert [c async for c in result.stream_text(delta=True, debounce_by=None)] == snapshot( ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] From 4d75d6df8d53045feb0343c9d42e61531776916b Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:44:45 -0700 Subject: [PATCH 10/34] Use peekable stream to access timestamp in groq --- pydantic_ai_slim/pydantic_ai/_utils.py | 66 +++++++++++++++++++++ pydantic_ai_slim/pydantic_ai/models/groq.py | 27 ++++++--- 2 files changed, 85 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 751e92d9..04a4c2b0 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -262,3 +262,69 @@ def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model """Type guard that checks a `tool_call_id` is not None both for static typing and runtime.""" assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}' return t.tool_call_id + + +class PeekableAsyncStream(Generic[T]): + """Wraps an async iterable of type T and allows peeking at the *next* item without consuming it. + + We only buffer one item at a time (the next item). Once that item is yielded, it is discarded. + This is a single-pass stream. + """ + + def __init__(self, source: AsyncIterable[T]): + self._source = source + self._source_iter: AsyncIterator[T] | None = None + self._buffer: T | None = None + self._exhausted = False + + async def peek(self) -> T | None: + """Returns the next item that would be yielded without consuming it. + + Returns None if the stream is exhausted. + """ + if self._exhausted: + return None + + # If we already have a buffered item, just return it. + if self._buffer is not None: + return self._buffer + + # Otherwise, we need to fetch the next item from the underlying iterator. + if self._source_iter is None: + self._source_iter = self._source.__aiter__() + + try: + self._buffer = await self._source_iter.__anext__() + except StopAsyncIteration: + self._exhausted = True + return None + + return self._buffer + + def __aiter__(self) -> AsyncIterator[T]: + # For a single-pass iteration, we can return self as the iterator. + return self + + async def __anext__(self) -> T: + """Yields the buffered item if present, otherwise fetches the next item from the underlying source. + + Raises StopAsyncIteration if the stream is exhausted. + """ + if self._exhausted: + raise StopAsyncIteration + + # If we have a buffered item, yield it. 
+ if self._buffer is not None: + item = self._buffer + self._buffer = None + return item + + # Otherwise, fetch the next item from the source. + if self._source_iter is None: + self._source_iter = self._source.__aiter__() + + try: + return await self._source_iter.__anext__() + except StopAsyncIteration: + self._exhausted = True + raise diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index fad0a186..6978a154 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -10,8 +10,8 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import _utils, result -from .._utils import guard_tool_call_id as _guard_tool_call_id +from .. import UnexpectedModelBehavior, _utils, result +from .._utils import PeekableAsyncStream, guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, ModelRequest, @@ -170,7 +170,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: response = await self._completions_create(messages, True, model_settings) async with response: - yield GroqStreamedResponse(response) + yield await self._process_streamed_response(response) @overload async def _completions_create( @@ -226,6 +226,16 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id)) return ModelResponse(items, timestamp=timestamp) + @staticmethod + async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if first_chunk is None: + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') + + return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)) + @classmethod def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" @@ -280,10 +290,11 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio class GroqStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Groq models.""" - _response: AsyncStream[ChatCompletionChunk] + _response: AsyncIterable[ChatCompletionChunk] + _timestamp: datetime - _timestamp: datetime | None = field(default=None, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) + _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) _content_part_index: int | None = field(default=None, init=False) _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) @@ -400,13 +411,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | ) def get(self, *, final: bool = False) -> ModelResponse: - return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self.timestamp()) + return 
ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) def usage(self) -> Usage: return self._usage def timestamp(self) -> datetime: - return self._timestamp or datetime.now(tz=timezone.utc) + return self._timestamp def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: From 7b959c466067d19e06577c52f7fea59a3276e405 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:04:09 -0700 Subject: [PATCH 11/34] Get openai tests passing --- pydantic_ai_slim/pydantic_ai/models/groq.py | 5 +-- pydantic_ai_slim/pydantic_ai/models/openai.py | 31 ++++++++++++------- tests/models/test_openai.py | 14 ++++++--- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 6978a154..2ab56103 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -324,9 +324,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | self._timestamp = self._timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) self._usage += _map_usage(chunk) - if not chunk.choices: + try: + choice = chunk.choices[0] + except IndexError: continue - choice = chunk.choices[0] # Handle the text part of the response content = choice.delta.content diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 4cc05166..2a7041d8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator, Iterable, Iterator +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -10,8 +10,8 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import _utils, result -from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from .. 
import UnexpectedModelBehavior, _utils, result +from .._utils import PeekableAsyncStream, guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( ModelMessage, ModelRequest, @@ -159,7 +159,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: response = await self._completions_create(messages, True, model_settings) async with response: - yield OpenAIStreamedResponse(response) + yield await self._process_streamed_response(response) @overload async def _completions_create( @@ -216,6 +216,16 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id)) return ModelResponse(items, timestamp=timestamp) + @staticmethod + async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if first_chunk is None: + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') + + return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)) + @classmethod def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" @@ -275,12 +285,12 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio class OpenAIStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for OpenAI models.""" - _response: AsyncStream[ChatCompletionChunk] + _response: AsyncIterable[ChatCompletionChunk] + _timestamp: datetime - _timestamp: datetime | None = field(default=None, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) - _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) + _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) _content_part: TextPart | None = field(default=None, init=False) _tool_call_parts: dict[int, ToolCallPart] = field(default_factory=dict, init=False) _async_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) @@ -304,7 +314,7 @@ async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | try: choice = chunk.choices[0] except IndexError: - raise StopAsyncIteration() + continue for e in self._update_parts_for_content_delta(choice.delta.content): yield e @@ -321,16 +331,13 @@ async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | current.function.arguments = _utils.add_optional( current.function.arguments, new.function.arguments ) - for e in self._update_parts_for_tool_call_delta(current, replace_existing_part): + for e in self._update_parts_for_tool_call_delta(new, replace_existing_part): yield e else: self._delta_tool_calls[new.index] = new for e in self._update_parts_for_tool_call_delta(new, True): yield e - if choice.finish_reason is not None: - raise StopAsyncIteration() - def get(self, *, final: bool = False) -> ModelResponse: items: list[ModelResponsePart] = [self._content_part] if self._content_part is not None else [] items.extend([self._tool_call_parts[k] for k in sorted(self._tool_call_parts.keys())]) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 71e83f21..f46f16f9 100644 
--- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -70,6 +70,9 @@ class MockAsyncStream: async def __anext__(self) -> chat.ChatCompletionChunk: return _utils.sync_anext(self._iter) + def __aiter__(self) -> MockAsyncStream: + return self + async def __aenter__(self): return self @@ -361,7 +364,7 @@ async def test_stream_text(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) @@ -374,7 +377,9 @@ async def test_stream_text_finish_reason(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.']) + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) assert result.is_complete @@ -453,6 +458,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, ] ) assert result.is_complete @@ -464,7 +470,7 @@ async def test_no_content(allow_model_requests: None): m = OpenAIModel('gpt-4', openai_client=mock_client) agent = Agent(m, result_type=MyTypedDict) - with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'): + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): async with agent.run_stream(''): pass @@ -481,6 +487,6 @@ async def test_no_delta(allow_model_requests: None): async with agent.run_stream('') as result: assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) From c5590cc4adf68b756d02db258d80ed38a31f93bf Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:41:33 -0700 Subject: [PATCH 12/34] Fix mistral tests --- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- .../pydantic_ai/models/mistral.py | 202 +++++++++++------- tests/models/test_mistral.py | 19 +- 3 files changed, 137 insertions(+), 86 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 2ab56103..38be7e8e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -296,6 +296,7 @@ class GroqStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) + _content_part_index: int | None = field(default=None, init=False) _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) @@ -321,7 +322,6 @@ async def _get_event_iterator(self) -> 
AsyncIterator[ModelResponseStreamEvent | current_part_index: int | None = None async for chunk in self._response: - self._timestamp = self._timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) self._usage += _map_usage(chunk) try: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 1dae05d6..ee48e8f8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import os -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -13,16 +13,20 @@ from typing_extensions import assert_never from .. import UnexpectedModelBehavior -from .._utils import now_utc as _now_utc +from .._utils import PeekableAsyncStream, now_utc as _now_utc from ..messages import ( ArgsJson, ModelMessage, ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -295,44 +299,17 @@ async def _process_streamed_response( response: MistralEventStreamAsync[MistralCompletionEvent], ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" - start_usage = Usage() + peekable_response = PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if first_chunk is None: + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') - # Iterate until we get either `tool_calls` or `content` from the first chunk. 
- while True: - try: - event = await response.__anext__() - chunk = event.data - except StopAsyncIteration as e: - raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e - - start_usage += _map_usage(chunk) - - if chunk.created: - timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) - else: - timestamp = _now_utc() - - if chunk.choices: - delta = chunk.choices[0].delta - content = _map_content(delta.content) - - tool_calls: list[MistralToolCall] | None = None - if delta.tool_calls: - tool_calls = delta.tool_calls - - if tool_calls or content and result_tools: - return MistralStreamedResponse( - {c.id if c.id else 'null': c for c in tool_calls or []}, - {c.name: c for c in result_tools}, - response, - content, - timestamp, - start_usage, - ) + if first_chunk.data.created: + timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc) + else: + timestamp = datetime.now(tz=timezone.utc) - elif content: - # return MistralStreamTextResponse(content, response, timestamp, start_usage) - raise NotImplementedError('TODO: Fix this branch') + return MistralStreamedResponse(peekable_response, timestamp, {c.name: c for c in result_tools}) @staticmethod def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall: @@ -466,59 +443,111 @@ def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]: assert_never(message) +MistralToolCallId = str | None + + @dataclass class MistralStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Mistral models.""" - _function_tools: dict[str, MistralToolCall] - _result_tools: dict[str, ToolDefinition] - _response: MistralEventStreamAsync[MistralCompletionEvent] - _delta_content: str | None + _response: AsyncIterable[MistralCompletionEvent] _timestamp: datetime - _usage: Usage + _result_tools: dict[str, ToolDefinition] - async def __anext__(self) -> None: - chunk = await self._response.__anext__() - self._usage += _map_usage(chunk.data) + _usage: Usage = field(default_factory=Usage, init=False) - try: - choice = chunk.data.choices[0] + _function_tools: dict[str, MistralToolCall] = field(default_factory=dict, init=False) + _delta_content: str = '' - except IndexError: - raise StopAsyncIteration() + _delta_tool_calls: dict[MistralToolCallId, MistralToolCall] = field(default_factory=dict, init=False) - if choice.finish_reason is not None: - raise StopAsyncIteration() + _result_part_index: int | None = field(default=None, init=False) + _content_part_index: int | None = field(default=None, init=False) + _tool_call_id_to_part_index: dict[MistralToolCallId, int] = field(default_factory=dict, init=False) + _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) - content = choice.delta.content - if self._result_tools: - if text := _map_content(content): - self._delta_content = (self._delta_content or '') + text + async def __anext__(self) -> ModelResponseStreamEvent | None: + if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() - def get(self, *, final: bool = False) -> ModelResponse: - calls: list[ModelResponsePart] = [] - if self._function_tools and self._result_tools or self._function_tools: - for tool_call in self._function_tools.values(): - tool = _map_mistral_to_pydantic_tool_call(tool_call) - calls.append(tool) + next_event = await self._event_iterator.__anext__() - elif self._delta_content and 
self._result_tools:
-            output_json: dict[str, Any] | None = pydantic_core.from_json(
-                self._delta_content, allow_partial='trailing-strings'
-            )
+        if isinstance(next_event, PartStartEvent):
+            self._parts[next_event.index] = next_event.part
+        elif isinstance(next_event, PartDeltaEvent):
+            existing_part = self._parts.get(next_event.index)
+            assert existing_part is not None, 'PartDeltaEvent without existing part'
+            self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index])
+
+        return next_event
 
-            if output_json:
-                for result_tool in self._result_tools.values():
-                    # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
-                    # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
-                    # Example with BaseModel and required fields.
-                    if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
-                        continue
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]:
+        # TODO: Simplify this through the use of a StreamedPartsManager or whatever
+        current_part_index: int | None = None
 
-                    tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
-                    calls.append(tool)
+        chunk: MistralCompletionEvent
+        async for chunk in self._response:
+            self._usage += _map_usage(chunk.data)
+
+            try:
+                choice = chunk.data.choices[0]
+            except IndexError:
+                continue
+
+            # Handle the text part of the response
+            content = choice.delta.content
+            text = _map_content(content)
+            if text:
+                if self._result_tools:
+                    self._delta_content += text
+                    maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
+                    if maybe_tool_call_part:
+                        if self._result_part_index is None:
+                            if current_part_index is None:
+                                current_part_index = 0
+                            else:
+                                # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events..
+                                current_part_index += 1
+                            self._result_part_index = current_part_index
+
+                        yield PartStartEvent(
+                            index=self._result_part_index,
+                            part=maybe_tool_call_part,
+                        )
+                else:
+                    if self._content_part_index is not None:
+                        yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(text))
+                    else:
+                        if current_part_index is None:
+                            current_part_index = 0
+                        else:
+                            # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events..
+                            current_part_index += 1
+
+                        self._content_part_index = current_part_index
+                        part = TextPart(text)
+                        yield PartStartEvent(index=current_part_index, part=part)
+
+            # Handle the tool calls
+            for dtc in choice.delta.tool_calls or []:
+                # It seems that mistral just sends full tool calls, so we just use them directly rather than building them up from deltas
+                part_index = self._tool_call_id_to_part_index.get(dtc.id)
+                if part_index is None:
+                    if current_part_index is None:
+                        current_part_index = 0
+                    else:
+                        # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. 
+                        current_part_index += 1
+                    self._tool_call_id_to_part_index[dtc.id] = current_part_index
+                    part_index = current_part_index
+                    yield PartStartEvent(
+                        index=part_index,
+                        part=ToolCallPart.from_raw_args(dtc.function.name, dtc.function.arguments, tool_call_id=dtc.id),
+                    )
 
-        return ModelResponse(calls, timestamp=self._timestamp)
+    def get(self, *, final: bool = False) -> ModelResponse:
+        return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp)
 
     def usage(self) -> Usage:
         return self._usage
@@ -526,6 +555,21 @@ def usage(self) -> Usage:
     def timestamp(self) -> datetime:
         return self._timestamp
 
+    @staticmethod
+    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
+        output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
+        if output_json:
+            for result_tool in result_tools.values():
+                # NOTE: Additional verification to prevent the JSON validation from crashing in `_result.py`
+                # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
+                # Example with BaseModel and required fields.
+                if not MistralStreamedResponse._validate_required_json_schema(
+                    output_json, result_tool.parameters_json_schema
+                ):
+                    continue
+
+                return ToolCallPart.from_raw_args(result_tool.name, output_json)
+
     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
         """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
@@ -544,9 +588,9 @@ def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[
             if not isinstance(json_dict[param], list):
                 return False
             for item in json_dict[param]:
-                if not isinstance(item, VALIDE_JSON_TYPE_MAPPING[param_items_type]):
+                if not isinstance(item, VALID_JSON_TYPE_MAPPING[param_items_type]):
                     return False
-            elif param_type and not isinstance(json_dict[param], VALIDE_JSON_TYPE_MAPPING[param_type]):
+            elif param_type and not isinstance(json_dict[param], VALID_JSON_TYPE_MAPPING[param_type]):
                 return False
 
             if isinstance(json_dict[param], dict) and 'properties' in param_schema:
@@ -557,7 +601,7 @@ def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[
         return True
 
 
-VALIDE_JSON_TYPE_MAPPING: dict[str, Any] = {
+VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
     'string': str,
     'integer': int,
     'number': float,
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 6274310a..ff9d4fd3 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -68,6 +68,9 @@ class MockAsyncStream:
     async def __anext__(self) -> MistralCompletionChunk:
         return _utils.sync_anext(self._iter)
 
+    def __aiter__(self):
+        return self
+
     async def __aenter__(self):
         return self
 
@@ -309,7 +312,7 @@ async def test_stream_text(allow_model_requests: None):
     async with agent.run_stream('') as result:
         # Then
         assert not result.is_complete
-        assert [c async for c in result.stream(debounce_by=None)] == snapshot(
+        assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(
             ['hello ', 'hello world ', 'hello world welcome ', 'hello world welcome mistral']
         )
         assert result.is_complete
@@ -329,7 +332,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None):
     async with agent.run_stream('') as result:
         # Then
         assert not result.is_complete
-        assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world', 'hello world.'])
+        assert [c 
async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) assert result.is_complete @@ -344,7 +349,7 @@ async def test_no_delta(allow_model_requests: None): async with agent.run_stream('') as result: # Then assert not result.is_complete - assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete assert result.usage().request_tokens == 3 assert result.usage().response_tokens == 3 @@ -1457,7 +1462,7 @@ async def get_location(loc_name: str) -> str: # Then assert not result.is_complete v = [c async for c in result.stream(debounce_by=None)] - assert v == snapshot(['final ', 'final response']) + assert v == snapshot(['final ', 'final response', 'final response']) assert result.is_complete assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) assert result.usage().request_tokens == 6 @@ -1495,7 +1500,9 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), + ModelResponse.from_text( + content='final response', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) + ), ] ) @@ -1557,7 +1564,7 @@ async def get_location(loc_name: str) -> str: async with agent.run_stream('User prompt value') as result: # Then assert not result.is_complete - v = [c async for c in result.stream(debounce_by=None)] + v = [c async for c in result.stream_text(debounce_by=None)] assert v == snapshot(['final ', 'final response']) assert result.is_complete assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc) From ff3637704fe9e674594e75784e2ef1a1993c5e07 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:55:05 -0700 Subject: [PATCH 13/34] Remove the ability to yield None from StreamingResponse iterator --- pydantic_ai_slim/pydantic_ai/agent.py | 2 -- pydantic_ai_slim/pydantic_ai/models/__init__.py | 12 ++++-------- pydantic_ai_slim/pydantic_ai/models/function.py | 6 +++--- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++--- pydantic_ai_slim/pydantic_ai/models/groq.py | 6 +++--- pydantic_ai_slim/pydantic_ai/models/mistral.py | 6 +++--- pydantic_ai_slim/pydantic_ai/models/openai.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/models/test.py | 10 ++++------ pydantic_ai_slim/pydantic_ai/result.py | 4 ++-- 9 files changed, 26 insertions(+), 34 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index db12fa86..78f9101b 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1011,8 +1011,6 @@ async def _handle_streamed_response( received_text = False async for maybe_part_event in streamed_response: - if maybe_part_event is None: - continue if isinstance(maybe_part_event, PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, TextPart): diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index f9c45e2a..f9fc92ce 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -142,17 +142,13 @@ async def request_stream( class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: - """Stream 
the response as an async iterable of (optional) `ModelResponseStreamEvent`s.
-
-        This is an async iterator that yields events as they are received. It may yield `None` when raw data is received
-        from the model but there is not enough information to produce a meaningful ModelResponseStreamEvent.
-        """
+    def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        """Stream the response as an async iterable of `ModelResponseStreamEvent`s."""
         return self
 
     @abstractmethod
-    async def __anext__(self) -> ModelResponseStreamEvent | None:
-        """Process the next chunk of the response, see above for why this may return `None`."""
+    async def __anext__(self) -> ModelResponseStreamEvent:
+        """Process the next chunk of the response."""
         raise NotImplementedError()
 
     @abstractmethod
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index 2ccc3b21..b56875f6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -194,14 +194,14 @@ class FunctionStreamedResponse(StreamedResponse):
     _content_part_index: int | None = field(default=None, init=False)
     _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False)
     _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False)
-    _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False)
+    _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
 
-    async def __anext__(self) -> ModelResponseStreamEvent | None:
+    async def __anext__(self) -> ModelResponseStreamEvent:
         if self._event_iterator is None:
             self._event_iterator = self._get_event_iterator()
         return await self._event_iterator.__anext__()
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for item in self._iter:
             # TODO: Create a PartsStreamManager class that wraps
             # _next_part_index, _content_part_index, _tool_call_index_to_part_index, and _parts
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index c20cf244..be330063 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -308,9 +308,9 @@ class GeminiStreamedResponse(StreamedResponse):
 
     _usage: result.Usage = field(default_factory=result.Usage, init=False)
     _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False)
-    _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False)
+    _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
 
-    async def __anext__(self) -> ModelResponseStreamEvent | None:
+    async def __anext__(self) -> ModelResponseStreamEvent:
         if self._event_iterator is None:
             self._event_iterator = self._get_event_iterator()
 
@@ -325,7 +325,7 @@ async def __anext__(self) -> ModelResponseStreamEvent | None:
 
         return next_event
 
-    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]:
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         current_part_index: int | None = None
         current_tool_call_name: str | None = None  # None means we are in a text part or have no parts at all
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index 38be7e8e..4eb742ef 
100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -300,9 +300,9 @@ class GroqStreamedResponse(StreamedResponse): _content_part_index: int | None = field(default=None, init=False) _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - async def __anext__(self) -> ModelResponseStreamEvent | None: + async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() @@ -317,7 +317,7 @@ async def __anext__(self) -> ModelResponseStreamEvent | None: return next_event - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: # noqa C901 + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa C901 # TODO: Simplify this through the use of a StreamedPartsManager or whatever current_part_index: int | None = None diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ee48e8f8..9aaaf4d8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -465,9 +465,9 @@ class MistralStreamedResponse(StreamedResponse): _content_part_index: int | None = field(default=None, init=False) _tool_call_id_to_part_index: dict[MistralToolCallId, int] = field(default_factory=dict, init=False) _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - async def __anext__(self) -> ModelResponseStreamEvent | None: + async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() @@ -482,7 +482,7 @@ async def __anext__(self) -> ModelResponseStreamEvent | None: return next_event - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # TODO: Simplify this through the use of a StreamedPartsManager or whatever current_part_index: int | None = None diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 2a7041d8..0c40353f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -293,19 +293,19 @@ class OpenAIStreamedResponse(StreamedResponse): _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) _content_part: TextPart | None = field(default=None, init=False) _tool_call_parts: dict[int, ToolCallPart] = field(default_factory=dict, init=False) - _async_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + _async_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: if self._async_iterator is None: self._async_iterator = 
self._get_async_iterator() return self._async_iterator - async def __anext__(self) -> ModelResponseStreamEvent | None: + async def __anext__(self) -> ModelResponseStreamEvent: if self._async_iterator is None: self._async_iterator = self._get_async_iterator() return await self._async_iterator.__anext__() - async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: if self._timestamp is None: self._timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index c8386d23..a4c2415b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -217,15 +217,13 @@ class TestStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent | None] | None = field(default=None, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - async def __anext__(self) -> ModelResponseStreamEvent | None: + async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() next_event = await self._event_iterator.__anext__() - if next_event is None: - return None if isinstance(next_event, PartStartEvent): self._parts[next_event.index] = next_event.part @@ -237,7 +235,7 @@ async def __anext__(self) -> ModelResponseStreamEvent | None: self._usage += _estimate_event_usage(next_event) return next_event - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent | None]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): if isinstance(part, TextPart): text = part.content @@ -422,7 +420,7 @@ def _char(self) -> str: return s -def _estimate_event_usage(event: ModelResponseStreamEvent | None) -> Usage: +def _estimate_event_usage(event: ModelResponseStreamEvent) -> Usage: response_tokens = 0 if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart): response_tokens = _estimate_string_usage(event.part.content) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 3176cc8f..088ee212 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -369,10 +369,10 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: def _get_usage_checking_stream_response( - stream_response: AsyncIterator[ModelResponseStreamEvent | None], + stream_response: AsyncIterator[ModelResponseStreamEvent], limits: UsageLimits | None, get_usage: Callable[[], Usage], -) -> AsyncIterator[ModelResponseStreamEvent | None]: +) -> AsyncIterator[ModelResponseStreamEvent]: if limits is not None and limits.has_token_limits(): async def _usage_checking_iterator(): From 5528ad1cb5029453b99bc1327fe104e1434a091f Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:56:27 -0700 Subject: [PATCH 14/34] Update example --- docs/multi-agent-applications.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/multi-agent-applications.md 
b/docs/multi-agent-applications.md index cda67859..90c2e552 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -148,9 +148,9 @@ async def main(): """ Usage( requests=4, - request_tokens=310, + request_tokens=309, response_tokens=32, - total_tokens=342, + total_tokens=341, details=None, ) """ From 0ff2f87a3813a4157c163087bd9468a8bf002f6d Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:08:09 -0700 Subject: [PATCH 15/34] Make PeekableAsyncStream work even if the stream can yield None --- pydantic_ai_slim/pydantic_ai/_utils.py | 20 +++++++++++-------- .../pydantic_ai/models/function.py | 18 ++++++----------- pydantic_ai_slim/pydantic_ai/models/groq.py | 6 +++--- .../pydantic_ai/models/mistral.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +++--- 5 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 04a4c2b0..53446afa 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -160,7 +160,7 @@ async def group_by_temporal( as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed Returns: - A context manager usable as an iterator async iterable of pairs of lists of items from the input async iterable, + A context manager usable as an async iterable of pairs of lists of items from the input async iterable, and a boolean indicating whether the item was final coming out of the iterator. """ if soft_max_interval is None: @@ -274,19 +274,19 @@ class PeekableAsyncStream(Generic[T]): def __init__(self, source: AsyncIterable[T]): self._source = source self._source_iter: AsyncIterator[T] | None = None - self._buffer: T | None = None + self._buffer: T | Unset = UNSET self._exhausted = False - async def peek(self) -> T | None: + async def peek(self) -> T | Unset: """Returns the next item that would be yielded without consuming it. - Returns None if the stream is exhausted. + Returns UNSET if the stream is exhausted. """ if self._exhausted: - return None + return UNSET # If we already have a buffered item, just return it. - if self._buffer is not None: + if not isinstance(self._buffer, Unset): return self._buffer # Otherwise, we need to fetch the next item from the underlying iterator. @@ -297,10 +297,14 @@ async def peek(self) -> T | None: self._buffer = await self._source_iter.__anext__() except StopAsyncIteration: self._exhausted = True - return None + return UNSET return self._buffer + async def is_exhausted(self) -> bool: + """Returns True if the stream is exhausted, False otherwise.""" + return isinstance(await self.peek(), Unset) + def __aiter__(self) -> AsyncIterator[T]: # For a single-pass iteration, we can return self as the iterator. return self @@ -314,9 +318,9 @@ async def __anext__(self) -> T: raise StopAsyncIteration # If we have a buffered item, yield it. - if self._buffer is not None: + if not isinstance(self._buffer, Unset): item = self._buffer - self._buffer = None + self._buffer = UNSET return item # Otherwise, fetch the next item from the source. diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 2ccc3b21..b56875f6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -12,6 +12,7 @@ from typing_extensions import TypeAlias, assert_never, overload from ..
import _utils, result +from .._utils import PeekableAsyncStream from ..messages import ( ModelMessage, ModelRequest, @@ -167,20 +168,13 @@ async def request_stream( assert ( self.stream_function is not None ), 'FunctionModel must receive a `stream_function` to support streamed requests' - response_stream = self.stream_function(messages, self.agent_info) + response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info)) - # Explicitly check that we get at least one value, so we can produce a nicer error message for misuse - try: - first = await response_stream.__anext__() - except StopAsyncIteration as e: - raise ValueError('Stream function must return at least one item') from e + first = await response_stream.peek() + if isinstance(first, _utils.Unset): + raise ValueError('Stream function must return at least one item') - async def peeked_stream(): - yield first - async for item in response_stream: - yield item - - yield FunctionStreamedResponse(peeked_stream()) + yield FunctionStreamedResponse(response_stream) @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 4eb742ef..839febdf 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -11,7 +11,7 @@ from typing_extensions import assert_never from .. import UnexpectedModelBehavior, _utils, result -from .._utils import PeekableAsyncStream, guard_tool_call_id as _guard_tool_call_id +from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, ModelRequest, @@ -229,9 +229,9 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: @staticmethod async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" - peekable_response = PeekableAsyncStream(response) + peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() - if first_chunk is None: + if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') return GroqStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 9aaaf4d8..efdaf9e6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -12,8 +12,8 @@ from httpx import AsyncClient as AsyncHTTPClient, Timeout from typing_extensions import assert_never -from .. import UnexpectedModelBehavior -from .._utils import PeekableAsyncStream, now_utc as _now_utc +from .. 
import UnexpectedModelBehavior, _utils +from .._utils import now_utc as _now_utc from ..messages import ( ArgsJson, ModelMessage, @@ -299,9 +299,9 @@ async def _process_streamed_response( response: MistralEventStreamAsync[MistralCompletionEvent], ) -> StreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" - peekable_response = PeekableAsyncStream(response) + peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() - if first_chunk is None: + if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') if first_chunk.data.created: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 0c40353f..e252be71 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -11,7 +11,7 @@ from typing_extensions import assert_never from .. import UnexpectedModelBehavior, _utils, result -from .._utils import PeekableAsyncStream, guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( ModelMessage, ModelRequest, @@ -219,9 +219,9 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: @staticmethod async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" - peekable_response = PeekableAsyncStream(response) + peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() - if first_chunk is None: + if isinstance(first_chunk, _utils.Unset): raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') return OpenAIStreamedResponse(peekable_response, datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)) From f21aaf02764d70a477d3916e2b2de612d23b9ae1 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 6 Jan 2025 16:23:24 -0700 Subject: [PATCH 16/34] Remove PartStopEvent --- pydantic_ai_slim/pydantic_ai/messages.py | 12 +---------- pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 ++---------- pydantic_ai_slim/pydantic_ai/models/groq.py | 21 ++++--------------- .../pydantic_ai/models/mistral.py | 19 +++-------------- pydantic_ai_slim/pydantic_ai/models/test.py | 2 -- tests/models/test_ollama.py | 1 - 6 files changed, 10 insertions(+), 58 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index a7e40e4d..82eddb0c 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -315,14 +315,4 @@ class PartDeltaEvent: event_kind: Literal['part_delta'] = 'part_delta' -@dataclass -class PartStopEvent: - """A part stop event.""" - - index: int - event_kind: Literal['part_stop'] = 'part_stop' - - -ModelResponseStreamEvent = Annotated[ - Union[PartStartEvent, PartDeltaEvent, PartStopEvent], pydantic.Discriminator('event_kind') -] +ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')] diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index be330063..f609d90f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py 
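With `PartStopEvent` removed, `ModelResponseStreamEvent` is just `PartStartEvent | PartDeltaEvent`, so a consumer can rebuild the in-progress response by replacing a part on each start event and applying each delta in place. A minimal sketch of that fold (the `fold_events` helper is illustrative, not part of this patch):

from pydantic_ai.messages import (
    ModelResponsePart,
    ModelResponseStreamEvent,
    PartDeltaEvent,
    PartStartEvent,
)


def fold_events(events: list[ModelResponseStreamEvent]) -> list[ModelResponsePart]:
    # Rebuild the response parts from a stream of start/delta events.
    parts: dict[int, ModelResponsePart] = {}
    for event in events:
        if isinstance(event, PartStartEvent):
            # A repeated start event at the same index fully replaces the old part.
            parts[event.index] = event.part
        elif isinstance(event, PartDeltaEvent):
            parts[event.index] = event.delta.apply(parts[event.index])
    return [parts[i] for i in sorted(parts)]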
@@ -337,12 +337,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # The following condition holds if and only if we are not already in a text part: if current_part_index is None or current_tool_call_name is not None: current_tool_call_name = None - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 - + current_part_index = 0 if current_part_index is None else current_part_index + 1 part = TextPart(gemini_part['text']) yield PartStartEvent(index=current_part_index, part=part) else: @@ -352,11 +347,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly # it would just be a bit more complicated. And we'd need to confirm the intended semantics. current_tool_call_name = gemini_part['function_call']['name'] - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 + current_part_index = 0 if current_part_index is None else current_part_index + 1 yield PartStartEvent( index=current_part_index, part=ToolCallPart.from_raw_args(current_tool_call_name, gemini_part['function_call']['args']), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 839febdf..554157b6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -317,7 +317,7 @@ async def __anext__(self) -> ModelResponseStreamEvent: return next_event - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa C901 + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # TODO: Simplify this through the use of a StreamedPartsManager or whatever current_part_index: int | None = None @@ -335,12 +335,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if self._content_part_index is not None: yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(content)) else: - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 - + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._content_part_index = current_part_index part = TextPart(content) yield PartStartEvent(index=current_part_index, part=part) @@ -366,11 +361,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: part_index = self._tool_call_index_to_part_index.get(dtc.index) if part_index is None: - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. 
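# A standalone sketch (not part of the diff): the conditional expression
# introduced in these hunks is equivalent to the four-line branch being
# removed, minus the commented-out PartStopEvent yield. The helper name
# below is illustrative only.
def _next_part_index(current: int | None) -> int:
    # None means no part has been opened yet; otherwise advance past the last one.
    return 0 if current is None else current + 1


assert _next_part_index(None) == 0
assert _next_part_index(2) == 3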
- current_part_index += 1 + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._tool_call_index_to_part_index[dtc.index] = current_part_index yield PartStartEvent( index=current_part_index, @@ -398,11 +389,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if dtc.function.name and dtc.function.arguments: # This is the first delta_tool_call we've received with this index - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._tool_call_index_to_part_index[dtc.index] = current_part_index yield PartStartEvent( index=current_part_index, diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index efdaf9e6..b4ff8008 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -504,11 +504,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: maybe_tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools) if maybe_tool_call_part: if self._result_part_index is None: - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._result_part_index = current_part_index yield PartStartEvent( @@ -519,12 +515,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if self._content_part_index is not None: yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(text)) else: - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. - current_part_index += 1 - + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._content_part_index = current_part_index part = TextPart(text) yield PartStartEvent(index=current_part_index, part=part) @@ -534,11 +525,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # It seems that mistral just sends full tool calls, so we just use them directly, rather than building part_index = self._tool_call_id_to_part_index.get(dtc.id) if part_index is None: - if current_part_index is None: - current_part_index = 0 - else: - # yield PartStopEvent(index=current_part_index) # TODO: Not sure if we want to keep these events.. 
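# A usage sketch for the PeekableAsyncStream change in patch 15 above (not
# part of the diff): with UNSET as the sentinel, a source that legitimately
# yields None can still be peeked without being consumed. The import paths
# refer to internal helpers in this repo.
import asyncio

from pydantic_ai._utils import PeekableAsyncStream, Unset


async def main() -> None:
    async def source():
        yield None  # a real item the old `T | None` peek() could not distinguish
        yield 'hello'

    stream = PeekableAsyncStream(source())
    first = await stream.peek()
    assert first is None and not isinstance(first, Unset)  # peeked, not consumed
    assert not await stream.is_exhausted()
    assert [item async for item in stream] == [None, 'hello']


asyncio.run(main())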
- current_part_index += 1 + current_part_index = 0 if current_part_index is None else current_part_index + 1 self._tool_call_id_to_part_index[dtc.id] = current_part_index part_index = current_part_index yield PartStartEvent( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a4c2415b..051726b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -248,10 +248,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: yield PartStartEvent(index=i, part=TextPart(content='')) for word in words: yield PartDeltaEvent(index=i, delta=TextPartDelta(content_delta=word)) - # yield PartStopEvent(index=i, part=part) else: yield PartStartEvent(index=i, part=part) - # yield PartStopEvent(index=i, part=part) def get(self, *, final: bool = False) -> ModelResponse: parts = [self._parts[index] for index in sorted(self._parts)] diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index c608c656..821c11f9 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -38,7 +38,6 @@ def test_init(): async def test_request_simple_success(allow_model_requests: None): c = completion_message(ChatCompletionMessage(content='world', role='assistant')) mock_client = MockOpenAI.create_mock(c) - print('here') m = OllamaModel('llama3.2', openai_client=mock_client, base_url=None) agent = Agent(m) From 5eb5079c05b1832bd1de9f83bf01f9bd655a9ce0 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 7 Jan 2025 07:05:26 -0700 Subject: [PATCH 17/34] Fix for python 3.9 --- pydantic_ai_slim/pydantic_ai/models/mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index b4ff8008..204d1fbf 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -443,7 +443,7 @@ def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]: assert_never(message) -MistralToolCallId = str | None +MistralToolCallId = Union[str, None] @dataclass From 68efa657d7cd6c70caca295f9139303724f34c23 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:59:31 -0700 Subject: [PATCH 18/34] Fix test --- pydantic_ai_slim/pydantic_ai/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 65610b5a..8c419498 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1101,7 +1101,7 @@ async def _handle_structured_response( final_result: _MarkFinalResult[RunResultData] | None = None parts: list[_messages.ModelRequestPart] = [] - if result_schema := result_schema: + if result_schema is not None: if match := result_schema.find_tool(tool_calls): call, result_tool = match try: @@ -1218,10 +1218,10 @@ async def _handle_streamed_response( if tool := self._function_tools.get(p.tool_name): tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) else: - parts.append(self._unknown_tool(p.tool_name, run_context)) + parts.append(self._unknown_tool(p.tool_name, run_context, result_schema)) if received_text and not tasks and not parts: - # Can only get here if self._allow_text_result is False + # Can only get here if self._allow_text_result returns 
`False` for the provided result_schema self._incr_result_retry(run_context) model_response = _messages.RetryPromptPart( content='Plain text responses are not permitted, please call one of the functions instead.', From f0a5f683c20dca590bdd1c31361997f15cfb24b0 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 7 Jan 2025 22:44:04 -0700 Subject: [PATCH 19/34] Add parts manager --- .../pydantic_ai/_parts_manager.py | 144 ++++++++++++++++++ pydantic_ai_slim/pydantic_ai/messages.py | 101 ++++++++++-- .../pydantic_ai/models/anthropic.py | 8 +- .../pydantic_ai/models/function.py | 61 ++------ pydantic_ai_slim/pydantic_ai/models/gemini.py | 63 ++++---- pydantic_ai_slim/pydantic_ai/models/groq.py | 110 +++---------- .../pydantic_ai/models/mistral.py | 81 ++++------ pydantic_ai_slim/pydantic_ai/models/openai.py | 107 ++++--------- pydantic_ai_slim/pydantic_ai/models/test.py | 10 +- pydantic_ai_slim/pydantic_ai/result.py | 23 ++- tests/models/test_groq.py | 8 +- tests/models/test_openai.py | 13 +- tests/test_agent.py | 30 ++-- tests/test_streaming.py | 85 ++++++++++- 14 files changed, 485 insertions(+), 359 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/_parts_manager.py diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py new file mode 100644 index 00000000..50478f60 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -0,0 +1,144 @@ +from __future__ import annotations as _annotations + +from collections.abc import Hashable +from dataclasses import dataclass, field +from typing import Any, Union + +from pydantic_ai import UnexpectedModelBehavior +from pydantic_ai.messages import ( + ModelResponsePart, + ModelResponseStreamEvent, + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ToolCallPart, + ToolCallPartDelta, +) + +VendorId = Hashable + + +ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]  # `X | Y` would fail at runtime on Python 3.9 + + +@dataclass +class ModelResponsePartsManager: + _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False) + _parts: list[ManagedPart] = field(default_factory=list, init=False) + + def get_parts(self) -> list[ModelResponsePart]: + return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)] + + def handle_text_delta(self, *, vendor_part_id: Hashable | None, content: str) -> ModelResponseStreamEvent | None: + # vendor_part_id=None means to use the latest part if it is a text part, otherwise make a new one + if not content: + return None + + existing_text_part_and_index: tuple[TextPart, int] | None = None + if vendor_part_id is None: + if self._parts: + latest_part = self._parts[-1] + part_index = len(self._parts) - 1 + if isinstance(latest_part, TextPart): + existing_text_part_and_index = latest_part, part_index + else: + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + if part_index is not None: + existing_part = self._parts[part_index] + if not isinstance(existing_part, TextPart): + raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}') + existing_text_part_and_index = existing_part, part_index + + if existing_text_part_and_index is None: + new_part_index = len(self._parts) + part = TextPart(content=content) + if vendor_part_id is not None: + self._vendor_id_to_part_index[vendor_part_id] = new_part_index + self._parts.append(part) + return PartStartEvent(index=new_part_index, part=part) + else: + existing_text_part, part_index = existing_text_part_and_index + part_delta =
TextPartDelta(content_delta=content) + self._parts[part_index] = part_delta.apply(existing_text_part) + return PartDeltaEvent(index=part_index, delta=part_delta) + + def handle_tool_call_delta( + self, + *, + vendor_part_id: Hashable | None, + tool_name: str | None, + args: str | dict[str, Any] | None, + tool_call_id: str | None, + ) -> ModelResponseStreamEvent | None: + # vendor_part_id=None means to use the latest part if it is a matching tool call part, otherwise make a new one + existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None + if vendor_part_id is None: + # If vendor_part_id is not provided, the tool_name must match the latest part to perform updates + if self._parts: + latest_part = self._parts[-1] + part_index = len(self._parts) - 1 + if ( + isinstance(latest_part, ToolCallPart) and (tool_name is None or latest_part.tool_name == tool_name) + ) or ( + isinstance(latest_part, ToolCallPartDelta) + and ( + tool_name is None + or latest_part.tool_name_delta is None + or latest_part.tool_name_delta == tool_name + ) + ): + existing_matching_part_and_index = latest_part, part_index + else: + part_index = self._vendor_id_to_part_index.get(vendor_part_id) + if part_index is not None: + existing_part = self._parts[part_index] + if not isinstance(existing_part, (ToolCallPartDelta, ToolCallPart)): + raise UnexpectedModelBehavior(f'Cannot apply a tool call delta to {existing_part=}') + existing_matching_part_and_index = existing_part, part_index + + if existing_matching_part_and_index is None: + delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) + part = delta.as_part() or delta + if vendor_part_id is not None: + self._vendor_id_to_part_index[vendor_part_id] = len(self._parts) + new_part_index = len(self._parts) + self._parts.append(part) + # Only emit a PartStartEvent if we have enough information to produce a full ToolCallPart + if isinstance(part, ToolCallPart): + return PartStartEvent(index=new_part_index, part=part) + else: + existing_part, part_index = existing_matching_part_and_index + delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) + updated_part = delta.apply(existing_part) + self._parts[part_index] = updated_part + if isinstance(updated_part, ToolCallPart): + if isinstance(existing_part, ToolCallPartDelta): + # In this case, we just upgraded a delta to a full part, so emit a PartStartEvent: + return PartStartEvent(index=part_index, part=updated_part) + else: + # In this case, we just updated an existing part, so emit a PartDeltaEvent: + return PartDeltaEvent(index=part_index, delta=delta) + + def handle_tool_call_part( + self, + *, + vendor_part_id: Hashable | None, + tool_name: str, + args: str | dict[str, Any], + tool_call_id: str | None = None, + ) -> ModelResponseStreamEvent: + new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id) + if vendor_part_id is None: + new_part_index = len(self._parts) + self._parts.append(new_part) + else: + maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id) + if maybe_part_index is not None: + new_part_index = maybe_part_index + self._parts[new_part_index] = new_part + else: + new_part_index = len(self._parts) + self._parts.append(new_part) + self._vendor_id_to_part_index[vendor_part_id] = new_part_index + return PartStartEvent(index=new_part_index, part=new_part) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py 
b/pydantic_ai_slim/pydantic_ai/messages.py index e383bdb5..ad4cf826 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace from datetime import datetime -from typing import Annotated, Any, Literal, Union, cast +from typing import Annotated, Any, Literal, Union, cast, overload import pydantic import pydantic_core @@ -257,7 +257,7 @@ class ModelResponse: @classmethod def from_text(cls, content: str, timestamp: datetime | None = None) -> Self: - return cls([TextPart(content)], timestamp=timestamp or _now_utc()) + return cls([TextPart(content=content)], timestamp=timestamp or _now_utc()) @classmethod def from_tool_call(cls, tool_call: ToolCallPart) -> Self: @@ -276,6 +276,7 @@ class TextPartDelta: """A text part delta.""" content_delta: str + part_delta_kind: Literal['text'] = 'text' def apply(self, part: ModelResponsePart) -> TextPart: @@ -288,16 +289,94 @@ def apply(self, part: ModelResponsePart) -> TextPart: class ToolCallPartDelta: """A tool call part delta.""" - args_json_delta: str + tool_name_delta: str | None = None + args_delta: str | dict[str, Any] | None = None + tool_call_id: str | None = None + part_delta_kind: Literal['tool_call'] = 'tool_call' - def apply(self, part: ModelResponsePart) -> ToolCallPart: - if not isinstance(part, ToolCallPart): - raise ValueError('Cannot apply ToolCallPartDeltas to non-ToolCallParts') - if not isinstance(part.args, ArgsJson): - raise ValueError('Cannot apply deltas to non-JSON tool arguments') - updated_json = part.args.args_json + self.args_json_delta - return replace(part, args=ArgsJson(updated_json)) + def as_part(self) -> ToolCallPart | None: + """Converts to a ToolCallPart if the required information is present, otherwise returns None.""" + if self.tool_name_delta is None or self.args_delta is None: + return None + + return ToolCallPart.from_raw_args( + self.tool_name_delta, + self.args_delta, + self.tool_call_id, + ) + + @overload + def apply(self, part: ModelResponsePart) -> ToolCallPart: ... + + @overload + def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ... 
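# A standalone sketch of the new delta-accumulation behaviour (not part of
# the diff): a ToolCallPartDelta stays a delta until both a tool name and
# some args have arrived, at which point apply() upgrades it to a full
# ToolCallPart. The tool name and JSON payload below are illustrative only.
from pydantic_ai.messages import ToolCallPart, ToolCallPartDelta

acc = ToolCallPartDelta(tool_name_delta='get_weather')
assert acc.as_part() is None  # a name alone is not enough for a full part

acc = ToolCallPartDelta(args_delta='{"city": ').apply(acc)
assert isinstance(acc, ToolCallPart)  # upgraded once args started arriving

acc = ToolCallPartDelta(args_delta='"Paris"}').apply(acc)  # JSON deltas are concatenated onto ArgsJson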
+ + def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: + if isinstance(part, ToolCallPart): + return self._apply_to_part(part) + + if isinstance(part, ToolCallPartDelta): + return self._apply_to_delta(part) + + raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}') + + def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: + if self.tool_name_delta: + # I'm not sure how common it is to have deltas on the tool name, but I've handled it here for completeness + updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta + delta = replace(delta, tool_name_delta=updated_tool_name_delta) + + if isinstance(self.args_delta, str): + if isinstance(delta.args_delta, dict): + raise NotImplementedError('Cannot apply a JSON args delta to a dict args delta') + updated_args_delta = (delta.args_delta or '') + self.args_delta + delta = replace(delta, args_delta=updated_args_delta) + elif isinstance(self.args_delta, dict): + if isinstance(delta.args_delta, str): + raise NotImplementedError('Cannot apply a dict args delta to a JSON args delta') + updated_args_delta = {**(delta.args_delta or {}), **self.args_delta} + delta = replace(delta, args_delta=updated_args_delta) + + if self.tool_call_id: + # Don't treat tool_call_id as a delta, just replace it + if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id: + raise ValueError('Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one') + delta = replace(delta, tool_call_id=self.tool_call_id) + + # If we have enough data to create a full ToolCallPart, do so: + if delta.tool_name_delta is not None and delta.args_delta is not None: + return ToolCallPart.from_raw_args( + delta.tool_name_delta, + delta.args_delta, + delta.tool_call_id, + ) + + return delta + + def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: + if self.tool_name_delta: + # I'm not sure how common it is to have deltas on the tool name, but I've handled it here for completeness + tool_name = part.tool_name + self.tool_name_delta + part = replace(part, tool_name=tool_name) + + if isinstance(self.args_delta, str): + if not isinstance(part.args, ArgsJson): + raise ValueError('Cannot apply deltas to non-JSON tool arguments') + updated_json = part.args.args_json + self.args_delta + part = replace(part, args=ArgsJson(updated_json)) + elif isinstance(self.args_delta, dict): + if not isinstance(part.args, ArgsDict): + raise ValueError('Cannot apply deltas to non-dict tool arguments') + updated_dict = {**(part.args.args_dict or {}), **self.args_delta} + part = replace(part, args=ArgsDict(updated_dict)) + + if self.tool_call_id: + # Don't treat tool_call_id as a delta, just replace it + if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id: + raise ValueError('Cannot apply a new tool_call_id to a ToolCallPart that already has one') + part = replace(part, tool_call_id=self.tool_call_id) + return part ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')] @@ -307,7 +386,7 @@ def apply(self, part: ModelResponsePart) -> ToolCallPart: class PartStartEvent: # TODO: Consider renaming to PartReplaceEvent, or somehow indicate full replacement is an option """If multiple PartStartEvents are received with the same index, the new one should fully replace the old one.""" - index: int # TODO: Consider replacing index here and below with 
part_id + index: int part: ModelResponsePart event_kind: Literal['part_start'] = 'part_start' diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 02995dcf..d1629f20 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -216,14 +216,14 @@ def _process_response(response: AnthropicMessage) -> ModelResponse: items: list[ModelResponsePart] = [] for item in response.content: if isinstance(item, TextBlock): - items.append(TextPart(item.text)) + items.append(TextPart(content=item.text)) else: assert isinstance(item, ToolUseBlock), 'unexpected item type' items.append( ToolCallPart.from_raw_args( - item.name, - cast(dict[str, Any], item.input), - item.id, + tool_name=item.name, + args=cast(dict[str, Any], item.input), + tool_call_id=item.id, ) ) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index bcd9f511..f84c161e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -12,21 +12,17 @@ from typing_extensions import TypeAlias, assert_never, overload from .. import _utils, result +from .._parts_manager import ModelResponsePartsManager from .._utils import PeekableAsyncStream from ..messages import ( ModelMessage, ModelRequest, ModelResponse, - ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, - TextPartDelta, ToolCallPart, - ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -184,10 +180,7 @@ class FunctionStreamedResponse(StreamedResponse): _iter: AsyncIterator[str | DeltaToolCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) - _next_part_index: int = field(default=0, init=False) - _content_part_index: int | None = field(default=None, init=False) - _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) - _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: @@ -197,50 +190,24 @@ async def __anext__(self) -> ModelResponseStreamEvent: async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: - # TODO: Create a PartsStreamManager class that wraps - # _next_part_index, _content_part_index, _tool_call_index_to_part_index, and _parts - # and is reusable across the different StreamedResponse implementations. 
if isinstance(item, str): - text = item - if self._content_part_index is None: - content_part = TextPart(content=text) - self._content_part_index = self._next_part_index - self._parts[self._content_part_index] = content_part - self._next_part_index += 1 - yield PartStartEvent(index=self._content_part_index, part=content_part) - else: - content_part = self._parts[self._content_part_index] - assert isinstance(content_part, TextPart), 'The content part must be a text part' - delta = TextPartDelta(content_delta=text) - self._parts[self._content_part_index] = delta.apply(content_part) - yield PartDeltaEvent(index=self._content_part_index, delta=delta) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) + if maybe_event is not None: + yield maybe_event else: delta_tool_calls = item for dtc_index, delta_tool_call in delta_tool_calls.items(): - existing_part_index = self._tool_call_index_to_part_index.get(dtc_index) - if existing_part_index is None: - new_part_index = self._next_part_index - self._tool_call_index_to_part_index[dtc_index] = new_part_index - self._next_part_index += 1 - assert ( - delta_tool_call.name is not None - ), 'The first delta_tool_call with a given index must include a tool name' - part = ToolCallPart.from_raw_args( - tool_name=delta_tool_call.name, args=delta_tool_call.json_args or '' - ) - self._parts[new_part_index] = part - yield PartStartEvent(index=new_part_index, part=part) - else: - existing_part = self._parts[existing_part_index] - assert isinstance(existing_part, ToolCallPart), 'Cannot switch to tool call part mid-stream' - if delta_tool_call.json_args is not None: - delta = ToolCallPartDelta(delta_tool_call.json_args) - self._parts[existing_part_index] = delta.apply(existing_part) - yield PartDeltaEvent(index=existing_part_index, delta=delta) + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc_index, + tool_name=delta_tool_call.name, + args=delta_tool_call.json_args, + tool_call_id=None, + ) + if maybe_event is not None: + yield maybe_event def get(self, *, final: bool = False) -> ModelResponse: - parts = [self._parts[index] for index in sorted(self._parts)] - return ModelResponse(parts, timestamp=self._timestamp) + return ModelResponse(self._parts_manager.get_parts(), timestamp=self._timestamp) def usage(self) -> result.Usage: return _estimate_usage([self.get()]) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 9a51e7b3..6159d737 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -8,24 +8,23 @@ from dataclasses import dataclass, field from datetime import datetime from typing import Annotated, Any, Literal, Protocol, Union +from uuid import uuid4 import pydantic from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse from typing_extensions import NotRequired, TypedDict, assert_never from .. 
import UnexpectedModelBehavior, _utils, exceptions, result +from .._parts_manager import ModelResponsePartsManager from ..messages import ( ModelMessage, ModelRequest, ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, - TextPartDelta, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -307,51 +306,41 @@ class GeminiStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) - _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() - - next_event = await self._event_iterator.__anext__() - - if isinstance(next_event, PartStartEvent): - self._parts[next_event.index] = next_event.part - elif isinstance(next_event, PartDeltaEvent): - existing_part = self._parts.get(next_event.index) - assert existing_part is not None, 'PartDeltaEvent without existing part' - self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) - - return next_event + return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - current_part_index: int | None = None - current_tool_call_name: str | None = None # None means we are in a text part or have no parts at all - async for gemini_response in self._get_gemini_responses(): candidate = gemini_response['candidates'][0] gemini_part: _GeminiPartUnion for gemini_part in candidate['content']['parts']: if 'text' in gemini_part: - # The following condition holds if and only if we are not already in a text part: - if current_part_index is None or current_tool_call_name is not None: - current_tool_call_name = None - current_part_index = 0 if current_part_index is None else current_part_index + 1 - part = TextPart(gemini_part['text']) - yield PartStartEvent(index=current_part_index, part=part) - else: - yield PartDeltaEvent(index=current_part_index, delta=TextPartDelta(gemini_part['text'])) + # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled + # amongst the tool call deltas + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id=None, content=gemini_part['text'] + ) + if maybe_event is not None: + yield maybe_event + elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. + # We do this by assigning a unique randomly generated "vendor_part_id". # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly # it would just be a bit more complicated. And we'd need to confirm the intended semantics. 
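# A minimal sketch of the vendor-part-id semantics used in the surrounding
# Gemini code (not part of the diff), driving the new ModelResponsePartsManager
# directly; the tool name and args are illustrative only.
from uuid import uuid4

from pydantic_ai._parts_manager import ModelResponsePartsManager

manager = ModelResponsePartsManager()
# vendor_part_id=None coalesces consecutive text deltas into one TextPart ...
manager.handle_text_delta(vendor_part_id=None, content='Hello ')
manager.handle_text_delta(vendor_part_id=None, content='world')
assert len(manager.get_parts()) == 1

# ... while a fresh uuid4() vendor id always opens a new part, which is why
# each complete Gemini function_call is kept as its own part.
manager.handle_tool_call_delta(
    vendor_part_id=uuid4(), tool_name='get_weather', args={'city': 'Paris'}, tool_call_id=None
)
assert len(manager.get_parts()) == 2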
- current_tool_call_name = gemini_part['function_call']['name'] - current_part_index = 0 if current_part_index is None else current_part_index + 1 - yield PartStartEvent( - index=current_part_index, - part=ToolCallPart.from_raw_args(current_tool_call_name, gemini_part['function_call']['args']), + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=uuid4(), + tool_name=gemini_part['function_call']['name'], + args=gemini_part['function_call']['args'], + tool_call_id=None, ) + if maybe_event is not None: + yield maybe_event else: assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}' @@ -388,7 +377,8 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: def get(self, *, final: bool = False) -> ModelResponse: """Get the `ModelResponse` at this point.""" - return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) + parts = self._parts_manager.get_parts() + return ModelResponse(parts=parts, timestamp=self._timestamp) def usage(self) -> result.Usage: return self._usage @@ -468,9 +458,14 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d items: list[ModelResponsePart] = [] for part in parts: if 'text' in part: - items.append(TextPart(part['text'])) + items.append(TextPart(content=part['text'])) elif 'function_call' in part: - items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args'])) + items.append( + ToolCallPart.from_raw_args( + tool_name=part['function_call']['name'], + args=part['function_call']['args'], + ) + ) elif 'function_response' in part: raise exceptions.UnexpectedModelBehavior( f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 554157b6..1a0f2109 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -11,6 +11,7 @@ from typing_extensions import assert_never from .. 
import UnexpectedModelBehavior, _utils, result +from .._parts_manager import ModelResponsePartsManager from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, @@ -18,14 +19,10 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, - TextPartDelta, ToolCallPart, - ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -44,7 +41,6 @@ from groq import NOT_GIVEN, AsyncGroq, AsyncStream from groq.types import chat from groq.types.chat import ChatCompletion, ChatCompletionChunk - from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall except ImportError as _import_error: raise ImportError( 'Please install `groq` to use the Groq model, ' @@ -220,10 +216,12 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: choice = response.choices[0] items: list[ModelResponsePart] = [] if choice.message.content is not None: - items.append(TextPart(choice.message.content)) + items.append(TextPart(content=choice.message.content)) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: - items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id)) + items.append( + ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id) + ) return ModelResponse(items, timestamp=timestamp) @staticmethod @@ -292,35 +290,17 @@ class GroqStreamedResponse(StreamedResponse): _response: AsyncIterable[ChatCompletionChunk] _timestamp: datetime - _usage: result.Usage = field(default_factory=result.Usage, init=False) - _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) - - _content_part_index: int | None = field(default=None, init=False) - _tool_call_index_to_part_index: dict[int, int] = field(default_factory=dict, init=False) - _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() - - next_event = await self._event_iterator.__anext__() - - if isinstance(next_event, PartStartEvent): - self._parts[next_event.index] = next_event.part - elif isinstance(next_event, PartDeltaEvent): - existing_part = self._parts.get(next_event.index) - assert existing_part is not None, 'PartDeltaEvent without existing part' - self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) - - return next_event + return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - # TODO: Simplify this through the use of a StreamedPartsManager or whatever - current_part_index: int | None = None - async for chunk in self._response: self._usage += _map_usage(chunk) @@ -332,74 +312,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: - if self._content_part_index is not None: - yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(content)) - else: - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._content_part_index = current_part_index - part = 
TextPart(content) - yield PartStartEvent(index=current_part_index, part=part) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + if maybe_event is not None: + yield maybe_event # Handle the tool calls for dtc in choice.delta.tool_calls or []: - if not dtc.function: - continue - - if existing := self._delta_tool_calls.get(dtc.index): - # We've already received a delta_tool_call for this index - existing.id = existing.id or dtc.id - if existing.function is None: - existing.function = dtc.function - else: - existing.function.name = _utils.add_optional(existing.function.name, dtc.function.name) - existing.function.arguments = _utils.add_optional( - existing.function.arguments, dtc.function.arguments - ) - - if not (existing.function.name and existing.function.arguments): - continue # We don't have enough information to create a part - - part_index = self._tool_call_index_to_part_index.get(dtc.index) - if part_index is None: - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._tool_call_index_to_part_index[dtc.index] = current_part_index - yield PartStartEvent( - index=current_part_index, - part=ToolCallPart.from_raw_args( - existing.function.name, existing.function.arguments, tool_call_id=existing.id - ), - ) - elif dtc.function.name: - # We don't currently nicely support streaming updates to the function call name - # So we just replace the whole part if the name has changed - yield PartStartEvent( - index=part_index, - part=ToolCallPart.from_raw_args( - existing.function.name, existing.function.arguments, tool_call_id=existing.id - ), - ) - elif dtc.function.arguments: - yield PartDeltaEvent( - index=part_index, - delta=ToolCallPartDelta(dtc.function.arguments), - ) - - else: - self._delta_tool_calls[dtc.index] = dtc - - if dtc.function.name and dtc.function.arguments: - # This is the first delta_tool_call we've received with this index - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._tool_call_index_to_part_index[dtc.index] = current_part_index - yield PartStartEvent( - index=current_part_index, - part=ToolCallPart.from_raw_args( - dtc.function.name, dtc.function.arguments, tool_call_id=dtc.id - ), - ) + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function and dtc.function.name, + args=dtc.function and dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event def get(self, *, final: bool = False) -> ModelResponse: - return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) + parts = self._parts_manager.get_parts() + return ModelResponse(parts=parts, timestamp=self._timestamp) def usage(self) -> Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 204d1fbf..ea8e0ea0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -13,6 +13,7 @@ from typing_extensions import assert_never from .. 
import UnexpectedModelBehavior, _utils +from .._parts_manager import ModelResponsePartsManager from .._utils import now_utc as _now_utc from ..messages import ( ArgsJson, @@ -21,12 +22,9 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, - TextPartDelta, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -284,11 +282,11 @@ def _process_response(response: MistralChatCompletionResponse) -> ModelResponse: parts: list[ModelResponsePart] = [] if text := _map_content(content): - parts.append(TextPart(text)) + parts.append(TextPart(content=text)) if isinstance(tool_calls, list): for tool_call in tool_calls: - tool = _map_mistral_to_pydantic_tool_call(tool_call) + tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call) parts.append(tool) return ModelResponse(parts, timestamp=timestamp) @@ -456,36 +454,24 @@ class MistralStreamedResponse(StreamedResponse): _usage: Usage = field(default_factory=Usage, init=False) - _function_tools: dict[str, MistralToolCall] = field(default_factory=dict, init=False) + # _function_tools: dict[str, MistralToolCall] = field(default_factory=dict, init=False) _delta_content: str = '' + # _delta_tool_calls: dict[MistralToolCallId, MistralToolCall] = field(default_factory=dict, init=False) + # _result_part_index: int | None = field(default=None, init=False) + # _content_part_index: int | None = field(default=None, init=False) + # _tool_call_id_to_part_index: dict[MistralToolCallId, int] = field(default_factory=dict, init=False) + # _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) - _delta_tool_calls: dict[MistralToolCallId, MistralToolCall] = field(default_factory=dict, init=False) - - _result_part_index: int | None = field(default=None, init=False) - _content_part_index: int | None = field(default=None, init=False) - _tool_call_id_to_part_index: dict[MistralToolCallId, int] = field(default_factory=dict, init=False) - _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: if self._event_iterator is None: self._event_iterator = self._get_event_iterator() - next_event = await self._event_iterator.__anext__() - - if isinstance(next_event, PartStartEvent): - self._parts[next_event.index] = next_event.part - elif isinstance(next_event, PartDeltaEvent): - existing_part = self._parts.get(next_event.index) - assert existing_part is not None, 'PartDeltaEvent without existing part' - self._parts[next_event.index] = next_event.delta.apply(self._parts[next_event.index]) - - return next_event + return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - # TODO: Simplify this through the use of a StreamedPartsManager or whatever - current_part_index: int | None = None - chunk: MistralCompletionEvent async for chunk in self._response: self._usage += _map_usage(chunk.data) @@ -499,42 +485,32 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: content = choice.delta.content text = _map_content(content) if text: + # Attempt to produce a result tool call from the received text if self._result_tools: self._delta_content += text maybe_tool_call_part = 
self._try_get_result_tool_from_text(self._delta_content, self._result_tools) if maybe_tool_call_part: - if self._result_part_index is None: - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._result_part_index = current_part_index - - yield PartStartEvent( - index=self._result_part_index, - part=maybe_tool_call_part, + yield self._parts_manager.handle_tool_call_part( + vendor_part_id='result', + tool_name=maybe_tool_call_part.tool_name, + args=maybe_tool_call_part.args_as_dict(), + tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - if self._content_part_index is not None: - yield PartDeltaEvent(index=self._content_part_index, delta=TextPartDelta(text)) - else: - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._content_part_index = current_part_index - part = TextPart(text) - yield PartStartEvent(index=current_part_index, part=part) - - # Handle the tool calls - for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) + if maybe_event is not None: + yield maybe_event + + # Handle the explicit tool calls + for index, dtc in enumerate(choice.delta.tool_calls or []): # It seems that mistral just sends full tool calls, so we just use them directly, rather than building - part_index = self._tool_call_id_to_part_index.get(dtc.id) - if part_index is None: - current_part_index = 0 if current_part_index is None else current_part_index + 1 - self._tool_call_id_to_part_index[dtc.id] = current_part_index - part_index = current_part_index - yield PartStartEvent( - index=part_index, - part=ToolCallPart.from_raw_args(dtc.function.name, dtc.function.arguments, tool_call_id=dtc.id), + yield self._parts_manager.handle_tool_call_part( + vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id ) def get(self, *, final: bool = False) -> ModelResponse: - return ModelResponse(parts=[self._parts[k] for k in sorted(self._parts)], timestamp=self._timestamp) + parts = self._parts_manager.get_parts() + return ModelResponse(parts=parts, timestamp=self._timestamp) def usage(self) -> Usage: return self._usage @@ -555,7 +531,8 @@ def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefini ): continue - return ToolCallPart.from_raw_args(result_tool.name, output_json) + # The following part_id will be thrown away + return ToolCallPart.from_raw_args(tool_name=result_tool.name, args=output_json) @staticmethod def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index e252be71..76243fcd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1,6 +1,6 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator +from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone @@ -11,6 +11,7 @@ from typing_extensions import assert_never from .. 
import UnexpectedModelBehavior, _utils, result +from .._parts_manager import ModelResponsePartsManager from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc from ..messages import ( ModelMessage, @@ -18,14 +19,10 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, - TextPartDelta, ToolCallPart, - ToolCallPartDelta, ToolReturnPart, UserPromptPart, ) @@ -44,7 +41,6 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream from openai.types import ChatModel, chat from openai.types.chat import ChatCompletionChunk - from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall except ImportError as _import_error: raise ImportError( 'Please install `openai` to use the OpenAI model, ' @@ -278,9 +274,6 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio assert_never(part) -_CONTENT_INDEX = 0 - - @dataclass class OpenAIStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for OpenAI models.""" @@ -290,58 +283,43 @@ class OpenAIStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) - _delta_tool_calls: dict[int, ChoiceDeltaToolCall] = field(default_factory=dict, init=False) - _content_part: TextPart | None = field(default=None, init=False) - _tool_call_parts: dict[int, ToolCallPart] = field(default_factory=dict, init=False) - _async_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: - if self._async_iterator is None: - self._async_iterator = self._get_async_iterator() - return self._async_iterator + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) + _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: - if self._async_iterator is None: - self._async_iterator = self._get_async_iterator() - return await self._async_iterator.__anext__() + if self._event_iterator is None: + self._event_iterator = self._get_events_iterator() + return await self._event_iterator.__anext__() - async def _get_async_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_events_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: - if self._timestamp is None: - self._timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) - self._usage += _map_usage(chunk) + try: choice = chunk.choices[0] except IndexError: continue - for e in self._update_parts_for_content_delta(choice.delta.content): - yield e - - for new in choice.delta.tool_calls or []: - if current := self._delta_tool_calls.get(new.index): - if new.function is not None: - if current.function is None: - replace_existing_part = True - current.function = new.function - else: - replace_existing_part = bool(new.function.name) - current.function.name = _utils.add_optional(current.function.name, new.function.name) - current.function.arguments = _utils.add_optional( - current.function.arguments, new.function.arguments - ) - for e in self._update_parts_for_tool_call_delta(new, replace_existing_part): - yield e - else: - self._delta_tool_calls[new.index] = new - for e in self._update_parts_for_tool_call_delta(new, True): - yield e + # Handle the text part of the response + content = choice.delta.content + if content 
is not None: + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + if maybe_event is not None: + yield maybe_event + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function and dtc.function.name, + args=dtc.function and dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event def get(self, *, final: bool = False) -> ModelResponse: - items: list[ModelResponsePart] = [self._content_part] if self._content_part is not None else [] - items.extend([self._tool_call_parts[k] for k in sorted(self._tool_call_parts.keys())]) - return ModelResponse(items, timestamp=self.timestamp()) + parts = self._parts_manager.get_parts() + return ModelResponse(parts=parts, timestamp=self._timestamp) def usage(self) -> Usage: return self._usage @@ -349,37 +327,6 @@ def usage(self) -> Usage: def timestamp(self) -> datetime: return self._timestamp or _now_utc() - def _update_parts_for_content_delta(self, choice_delta_content: str | None) -> Iterator[ModelResponseStreamEvent]: - if choice_delta_content is None: - return - - existing_content = self._content_part - if existing_content is None: - part = TextPart(content=choice_delta_content) - self._content_part = part - yield PartStartEvent(index=_CONTENT_INDEX, part=part) - else: - delta = TextPartDelta(content_delta=choice_delta_content) - self._content_part = delta.apply(existing_content) - yield PartDeltaEvent(index=_CONTENT_INDEX, delta=delta) - - def _update_parts_for_tool_call_delta( - self, tc: ChoiceDeltaToolCall, replace: bool - ) -> Iterator[ModelResponseStreamEvent]: - if tc.function is None: - return - - if replace: - assert tc.function.name is not None - new_part = ToolCallPart.from_raw_args(tc.function.name, tc.function.arguments or '', tc.id) - self._tool_call_parts[tc.index] = new_part - yield PartStartEvent(index=tc.index, part=new_part) - else: - assert (existing_part := self._tool_call_parts.get(tc.index)) is not None - delta = ToolCallPartDelta(args_json_delta=tc.function.arguments or '') - self._tool_call_parts[tc.index] = delta.apply(existing_part) - yield PartDeltaEvent(index=tc.index, delta=delta) - def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: return chat.ChatCompletionMessageToolCallParam( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 051726b8..1f9a90e0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -216,7 +216,7 @@ class TestStreamedResponse(StreamedResponse): _usage: Usage _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_by_index: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: @@ -226,11 +226,11 @@ async def __anext__(self) -> ModelResponseStreamEvent: next_event = await self._event_iterator.__anext__() if isinstance(next_event, PartStartEvent): - self._parts[next_event.index] = next_event.part + self._parts_by_index[next_event.index] = next_event.part elif isinstance(next_event, PartDeltaEvent): - existing_part = self._parts.get(next_event.index) + existing_part = self._parts_by_index.get(next_event.index) assert 
existing_part is not None, 'PartDeltaEvent without existing part' - self._parts[next_event.index] = next_event.delta.apply(existing_part) + self._parts_by_index[next_event.index] = next_event.delta.apply(existing_part) self._usage += _estimate_event_usage(next_event) return next_event @@ -252,7 +252,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: yield PartStartEvent(index=i, part=part) def get(self, *, final: bool = False) -> ModelResponse: - parts = [self._parts[index] for index in sorted(self._parts)] + parts = [self._parts_by_index[index] for index in sorted(self._parts_by_index)] return ModelResponse(parts, timestamp=self._timestamp) def usage(self) -> Usage: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 088ee212..47c73887 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -203,7 +203,8 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu An async iterable of the response data. """ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_result(structured_message, allow_partial=not is_last) + result = await self.validate_structured_result(structured_message, allow_partial=not is_last) + yield result async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. @@ -227,26 +228,24 @@ async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: # TODO: This needs to be rolled into the group_by_temporal below msg = self._stream_response.get() for i, part in enumerate(msg.parts): - # TODO: Probably need to replace this usage of index with a (tracked) part ID or similar - # (It's not guaranteed that this index `i` matches what comes out of the maybe_event.index below) if isinstance(part, TextPart) and part.content: yield part.content, i async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: async for events, _is_final in group_iter: - for maybe_event in events: + for event in events: if ( - isinstance(maybe_event, _messages.PartStartEvent) - and isinstance(maybe_event.part, _messages.TextPart) - and maybe_event.part.content + isinstance(event, _messages.PartStartEvent) + and isinstance(event.part, _messages.TextPart) + and event.part.content ): - yield maybe_event.part.content, maybe_event.index + yield event.part.content, event.index elif ( - isinstance(maybe_event, _messages.PartDeltaEvent) - and isinstance(maybe_event.delta, _messages.TextPartDelta) - and maybe_event.delta.content_delta + isinstance(event, _messages.PartDeltaEvent) + and isinstance(event.delta, _messages.TextPartDelta) + and event.delta.content_delta ): - yield maybe_event.delta.content_delta, maybe_event.index + yield event.delta.content_delta, event.index with _logfire.span('response stream text') as lf_span: if delta: diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 4fd250b4..6f3522d6 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -298,7 +298,7 @@ async def get_location(loc_name: str) -> str: tool_call_id='2', ) ], - timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), ModelRequest( parts=[ @@ -419,7 +419,10 @@ async def test_stream_structured(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='', 
timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ - ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"first": "One", "second": "Two"}')) + ToolCallPart( + tool_name='final_result', + args=ArgsJson(args_json='{"first": "One", "second": "Two"}'), + ) ], timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), @@ -456,6 +459,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, ] ) assert result.is_complete diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index f46f16f9..995984b3 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -13,7 +13,6 @@ from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, _utils from pydantic_ai.messages import ( - ArgsJson, ModelRequest, ModelResponse, RetryPromptPart, @@ -197,9 +196,9 @@ async def test_request_structured_response(allow_model_requests: None): ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse( parts=[ - ToolCallPart( + ToolCallPart.from_raw_args( tool_name='final_result', - args=ArgsJson(args_json='{"response": [1, 2, 123]}'), + args='{"response": [1, 2, 123]}', tool_call_id='123', ) ], @@ -284,9 +283,9 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[ - ToolCallPart( + ToolCallPart.from_raw_args( tool_name='get_location', - args=ArgsJson(args_json='{"loc_name": "San Fransisco"}'), + args='{"loc_name": "San Fransisco"}', tool_call_id='1', ) ], @@ -304,9 +303,9 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[ - ToolCallPart( + ToolCallPart.from_raw_args( tool_name='get_location', - args=ArgsJson(args_json='{"loc_name": "London"}'), + args='{"loc_name": "London"}', tool_call_id='2', ) ], diff --git a/tests/test_agent.py b/tests/test_agent.py index ed03e1b7..f3589ad7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -260,7 +260,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ] ), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"response": ["foo", "bar"]}'))], + parts=[ToolCallPart.from_raw_args(tool_name='final_result', args='{"response": ["foo", "bar"]}')], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -510,7 +510,7 @@ async def ret_a(x: str) -> str: ] ), ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], + parts=[ToolCallPart.from_raw_args(tool_name='ret_a', args={'x': 'a'})], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -622,7 +622,13 @@ async def ret_a(x: str) -> str: parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] ), ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}), tool_call_id=None)], + parts=[ + ToolCallPart( + tool_name='final_result', + args=ArgsDict(args_dict={'a': 0}), + tool_call_id=None, + ) + ], timestamp=IsNow(tz=timezone.utc), ), ModelRequest( @@ -1016,11 +1022,11 @@ def another_tool(y: int) -> int: ), ModelResponse( parts=[ - ToolCallPart(tool_name='regular_tool', args=ArgsDict(args_dict={'x': 42})), - ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'value': 'first'})), - ToolCallPart(tool_name='another_tool', args=ArgsDict(args_dict={'y': 2})), - ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'value': 
'second'})), - ToolCallPart(tool_name='unknown_tool', args=ArgsDict(args_dict={'value': '???'})), + ToolCallPart.from_raw_args(tool_name='regular_tool', args={'x': 42}), + ToolCallPart.from_raw_args(tool_name='final_result', args={'value': 'first'}), + ToolCallPart.from_raw_args(tool_name='another_tool', args={'y': 2}), + ToolCallPart.from_raw_args(tool_name='final_result', args={'value': 'second'}), + ToolCallPart.from_raw_args(tool_name='unknown_tool', args={'value': '???'}), ], timestamp=IsNow(tz=timezone.utc), ), @@ -1093,10 +1099,10 @@ def another_tool(y: int) -> int: # pragma: no cover ), ModelResponse( parts=[ - ToolCallPart(tool_name='regular_tool', args=ArgsDict(args_dict={'x': 1})), - ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'value': 'final'})), - ToolCallPart(tool_name='another_tool', args=ArgsDict(args_dict={'y': 2})), - ToolCallPart(tool_name='unknown_tool', args=ArgsDict(args_dict={'value': '???'})), + ToolCallPart.from_raw_args(tool_name='regular_tool', args={'x': 1}), + ToolCallPart.from_raw_args(tool_name='final_result', args={'value': 'final'}), + ToolCallPart.from_raw_args(tool_name='another_tool', args={'y': 2}), + ToolCallPart.from_raw_args(tool_name='unknown_tool', args={'value': '???'}), ], timestamp=IsNow(tz=timezone.utc), ), diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f0dae823..1c2d3b49 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import datetime import json from collections.abc import AsyncIterator from datetime import timezone @@ -503,7 +504,6 @@ def another_tool(y: int) -> int: ) -@pytest.mark.xfail(reason='final result tool not first is not yet supported') async def test_early_strategy_with_final_result_in_middle(): """Test that 'early' strategy stops at first final result, regardless of position.""" tool_called: list[str] = [] @@ -531,14 +531,93 @@ def another_tool(y: int) -> int: # pragma: no cover async with agent.run_stream('test early strategy with final result in middle') as result: response = await result.get_data() - assert response.value == snapshot('first') + assert response.value == snapshot('final') messages = result.all_messages() # Verify no tools were called assert tool_called == [] # Verify we got appropriate tool returns - assert messages == snapshot() + assert messages == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='test early strategy with final ' 'result in middle', + timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='user-prompt', + ) + ], + kind='request', + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='regular_tool', + args=ArgsJson(args_json='{"x": 1}'), + tool_call_id=None, + part_kind='tool-call', + ), + ToolCallPart( + tool_name='final_result', + args=ArgsJson(args_json='{"value": "final"}'), + tool_call_id=None, + part_kind='tool-call', + ), + ToolCallPart( + tool_name='another_tool', + args=ArgsJson(args_json='{"y": 2}'), + tool_call_id=None, + part_kind='tool-call', + ), + ToolCallPart( + tool_name='unknown_tool', + args=ArgsJson(args_json='{"value": "???"}'), + tool_call_id=None, + part_kind='tool-call', + ), + ], + timestamp=IsNow(tz=datetime.timezone.utc), + kind='response', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='regular_tool', + content='Tool not executed - a final ' 'result was already processed.', + tool_call_id=None, + timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', + ), + ToolReturnPart( + 
tool_name='final_result', + content='Final result processed.', + tool_call_id=None, + timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', + ), + ToolReturnPart( + tool_name='another_tool', + content='Tool not executed - a final ' 'result was already processed.', + tool_call_id=None, + timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='tool-return', + ), + RetryPromptPart( + content='Unknown tool name: ' + "'unknown_tool'. Available tools: " + 'regular_tool, another_tool, ' + 'final_result', + tool_name=None, + tool_call_id=None, + timestamp=IsNow(tz=datetime.timezone.utc), + part_kind='retry-prompt', + ), + ], + kind='request', + ), + ] + ) async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(): From e024fe416659a6e1ddd8fae98279d1ca13665c2e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 7 Jan 2025 22:57:16 -0700 Subject: [PATCH 20/34] Fix syntax issue for 3.9 --- pydantic_ai_slim/pydantic_ai/_parts_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 50478f60..e0ae01c0 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -2,7 +2,7 @@ from collections.abc import Hashable from dataclasses import dataclass, field -from typing import Any +from typing import Any, Union from pydantic_ai import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -19,7 +19,7 @@ VendorId = Hashable -ManagedPart = ModelResponsePart | ToolCallPartDelta +ManagedPart = Union[ModelResponsePart, ToolCallPartDelta] @dataclass From 8b712af8b94a97ef32aad3a40b924abac93793c8 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 00:04:58 -0700 Subject: [PATCH 21/34] A bit more clean-up --- .../pydantic_ai/_parts_manager.py | 3 -- pydantic_ai_slim/pydantic_ai/messages.py | 2 +- pydantic_ai_slim/pydantic_ai/models/test.py | 29 ++++++++++--------- pydantic_ai_slim/pydantic_ai/result.py | 2 -- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index e0ae01c0..005ba753 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -32,9 +32,6 @@ def get_parts(self) -> list[ModelResponsePart]: def handle_text_delta(self, *, vendor_part_id: Hashable | None, content: str) -> ModelResponseStreamEvent | None: # vendor_part_id=None means to use the latest part if it is a text part, otherwise make a new one - if not content: - return None - existing_text_part_and_index: tuple[TextPart, int] | None = None if vendor_part_id is None: if self._parts: diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index ad4cf826..84ca6096 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -383,7 +383,7 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart: @dataclass -class PartStartEvent: # TODO: Consider renaming to PartReplaceEvent, or somehow indicate full replacement is an option +class PartStartEvent: """If multiple PartStartEvents are received with the same index, the new one should fully replace the old one.""" index: int diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py 
index 1f9a90e0..728ae7ac 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -11,7 +11,9 @@ import pydantic_core from .. import _utils +from .._parts_manager import ModelResponsePartsManager from ..messages import ( + ArgsJson, ModelMessage, ModelRequest, ModelResponse, @@ -216,7 +218,7 @@ class TestStreamedResponse(StreamedResponse): _usage: Usage _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _parts_by_index: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) async def __anext__(self) -> ModelResponseStreamEvent: @@ -225,14 +227,8 @@ async def __anext__(self) -> ModelResponseStreamEvent: next_event = await self._event_iterator.__anext__() - if isinstance(next_event, PartStartEvent): - self._parts_by_index[next_event.index] = next_event.part - elif isinstance(next_event, PartDeltaEvent): - existing_part = self._parts_by_index.get(next_event.index) - assert existing_part is not None, 'PartDeltaEvent without existing part' - self._parts_by_index[next_event.index] = next_event.delta.apply(existing_part) - - self._usage += _estimate_event_usage(next_event) + update = _estimate_event_usage(next_event) + self._usage += update return next_event async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: @@ -245,14 +241,21 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if len(words) == 1 and len(text) > 2: mid = len(text) // 2 words = [text[:mid], text[mid:]] - yield PartStartEvent(index=i, part=TextPart(content='')) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') + if maybe_event is not None: + yield maybe_event for word in words: - yield PartDeltaEvent(index=i, delta=TextPartDelta(content_delta=word)) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) + if maybe_event is not None: + yield maybe_event else: - yield PartStartEvent(index=i, part=part) + args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict + yield self._parts_manager.handle_tool_call_part( + vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id + ) def get(self, *, final: bool = False) -> ModelResponse: - parts = [self._parts_by_index[index] for index in sorted(self._parts_by_index)] + parts = self._parts_manager.get_parts() return ModelResponse(parts, timestamp=self._timestamp) def usage(self) -> Usage: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 47c73887..fec69e64 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -225,7 +225,6 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: # if the response currently has any parts with content, yield those before streaming - # TODO: This needs to be rolled into the group_by_temporal below msg = self._stream_response.get() for i, part in enumerate(msg.parts): if isinstance(part, TextPart) and part.content: @@ -284,7 +283,6 @@ async def stream_structured( with _logfire.span('response stream structured') as lf_span: # if the message currently has any parts with content, yield before streaming - # 
TODO: This needs to be rolled into the group_by_temporal below...
         msg = self._stream_response.get()
         for part in msg.parts:
             if part.has_content():

From f647a3810d4798c576184f0fcf6c7adf4860818e Mon Sep 17 00:00:00 2001
From: David Montague <35119617+dmontagu@users.noreply.github.com>
Date: Wed, 8 Jan 2025 12:50:02 -0700
Subject: [PATCH 22/34] Document the parts manager

---
 .../pydantic_ai/_parts_manager.py             | 122 +++++++++++++++++-
 .../pydantic_ai/models/function.py            |   4 +-
 pydantic_ai_slim/pydantic_ai/models/gemini.py |   6 +-
 pydantic_ai_slim/pydantic_ai/models/groq.py   |   4 +-
 .../pydantic_ai/models/mistral.py             |   4 +-
 pydantic_ai_slim/pydantic_ai/models/openai.py |   4 +-
 pydantic_ai_slim/pydantic_ai/models/test.py   |   8 +-
 7 files changed, 122 insertions(+), 30 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py
index 005ba753..201c41af 100644
--- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py
@@ -1,3 +1,15 @@
+"""This module provides functionality to manage and update parts of a model's streamed response.

+
+The manager tracks which parts (in particular, text and tool calls) correspond to which
+vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model),
+and produces PydanticAI-format events as appropriate for consumers of the streaming APIs.
+
+The "vendor-specific identifiers" to use depend on the semantics of the responses from the vendor,
+and are tightly coupled to the specific model being used and to the PydanticAI Model subclass implementation.
+
+This `ModelResponsePartsManager` is used in each of the subclasses of StreamedResponse as a way to consolidate event-emitting logic.
+"""
+
 from __future__ import annotations as _annotations
 
 from collections.abc import Hashable
 from dataclasses import dataclass, field
@@ -17,29 +29,78 @@
 )
 
 VendorId = Hashable
-
+"""
+Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
+"""
 
 ManagedPart = Union[ModelResponsePart, ToolCallPartDelta]
+"""
+A union of types that are managed by the ModelResponsePartsManager.
+Because many vendors have streaming APIs that may produce not-fully-formed tool calls,
+this includes ToolCallPartDelta's in addition to the more fully-formed ModelResponsePart's.
+"""
 
 
 @dataclass
 class ModelResponsePartsManager:
+    """Manages a sequence of parts that make up a model's streamed response.
+
+    Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
+
+    Attributes:
+        _vendor_id_to_part_index: Maps a vendor's "part" ID (if provided) to the index
+            in the `_parts` list where that part (or its delta) resides.
+        _parts: A list of parts (text or tool calls) that make up the
+            current state of the model's response.
+    """
+
     _vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
     _parts: list[ManagedPart] = field(default_factory=list, init=False)
 
     def get_parts(self) -> list[ModelResponsePart]:
+        """Return only model response parts that are complete (i.e., not ToolCallPartDelta's).
+
+        Returns:
+            A list of ModelResponsePart objects. ToolCallPartDelta objects are excluded.
+ """ return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)] - def handle_text_delta(self, *, vendor_part_id: Hashable | None, content: str) -> ModelResponseStreamEvent | None: - # vendor_part_id=None means to use the latest part if it is a text part, otherwise make a new one + def handle_text_delta( + self, + *, + vendor_part_id: Hashable | None, + content: str, + ) -> ModelResponseStreamEvent: + """Handle incoming text content, creating or updating a TextPart in the manager as appropriate. + + When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart; + otherwise, a new TextPart is created. When a non-None ID is specified, the TextPart corresponding + to that vendor ID is either created or updated. + + Args: + vendor_part_id: The ID the vendor uses to identify this piece + of text. If None, a new part will be created unless the latest part is already + a TextPart. + content: The text content to append to the appropriate TextPart. + + Returns: + A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated. + + Raises: + UnexpectedModelBehavior: If attempting to apply text content to a part that is + not a TextPart. + """ existing_text_part_and_index: tuple[TextPart, int] | None = None + if vendor_part_id is None: + # If the vendor_part_id is None, check if the latest part is a TextPart to update if self._parts: latest_part = self._parts[-1] part_index = len(self._parts) - 1 if isinstance(latest_part, TextPart): existing_text_part_and_index = latest_part, part_index else: + # Otherwise, attempt to look up an existing TextPart by vendor_part_id part_index = self._vendor_id_to_part_index.get(vendor_part_id) if part_index is not None: existing_part = self._parts[part_index] @@ -48,6 +109,7 @@ def handle_text_delta(self, *, vendor_part_id: Hashable | None, content: str) -> existing_text_part_and_index = existing_part, part_index if existing_text_part_and_index is None: + # There is no existing text part that should be updated, so create a new one new_part_index = len(self._parts) part = TextPart(content=content) if vendor_part_id is not None: @@ -55,6 +117,7 @@ def handle_text_delta(self, *, vendor_part_id: Hashable | None, content: str) -> self._parts.append(part) return PartStartEvent(index=new_part_index, part=part) else: + # Update the existing TextPart with the new content delta existing_text_part, part_index = existing_text_part_and_index part_delta = TextPartDelta(content_delta=content) self._parts[part_index] = part_delta.apply(existing_text_part) @@ -68,10 +131,35 @@ def handle_tool_call_delta( args: str | dict[str, Any] | None, tool_call_id: str | None, ) -> ModelResponseStreamEvent | None: - # vendor_part_id=None means to use the latest part if it is a matching tool call part, otherwise make a new one + """Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`. + + Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which + point they are upgraded to `ToolCallPart`s. + + If `vendor_part_id` is None, updates the latest matching ToolCallPart (or ToolCallPartDelta) + if any. Otherwise, a new part (or delta) may be created. + + Args: + vendor_part_id: The ID the vendor uses for this tool call. + If None, the latest matching tool call may be updated. + tool_name: The name of the tool. If None, the manager does not enforce + a name match when `vendor_part_id` is None. 
+ args: Arguments for the tool call, either as a string or a dictionary of key-value pairs. + tool_call_id: An optional string representing an identifier for this tool call. + + Returns: + - A `PartStartEvent` if a new (fully realized) ToolCallPart is created. + - A `PartDeltaEvent` if an existing part is updated. + - `None` if no new event is emitted (e.g., the part is still incomplete). + + Raises: + UnexpectedModelBehavior: If attempting to apply a tool call delta to a part that is not + a ToolCallPart or ToolCallPartDelta. + """ existing_matching_part_and_index: tuple[ToolCallPartDelta | ToolCallPart, int] | None = None + if vendor_part_id is None: - # If vendor_part_id is not provided, the tool_name must match the latest part to perform updates + # vendor_part_id is None, so check if the latest part is a matching tool call or delta to update if self._parts: latest_part = self._parts[-1] part_index = len(self._parts) - 1 @@ -87,6 +175,7 @@ def handle_tool_call_delta( ): existing_matching_part_and_index = latest_part, part_index else: + # vendor_part_id is provided, so look up the corresponding part or delta part_index = self._vendor_id_to_part_index.get(vendor_part_id) if part_index is not None: existing_part = self._parts[part_index] @@ -95,6 +184,7 @@ def handle_tool_call_delta( existing_matching_part_and_index = existing_part, part_index if existing_matching_part_and_index is None: + # No matching part/delta was found, so create a new ToolCallPartDelta (or ToolCallPart if fully formed) delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) part = delta.as_part() or delta if vendor_part_id is not None: @@ -105,16 +195,17 @@ def handle_tool_call_delta( if isinstance(part, ToolCallPart): return PartStartEvent(index=new_part_index, part=part) else: + # Update the existing part or delta with the new information existing_part, part_index = existing_matching_part_and_index delta = ToolCallPartDelta(tool_name_delta=tool_name, args_delta=args, tool_call_id=tool_call_id) updated_part = delta.apply(existing_part) self._parts[part_index] = updated_part if isinstance(updated_part, ToolCallPart): if isinstance(existing_part, ToolCallPartDelta): - # In this case, we just upgraded a delta to a full part, so emit a PartStartEvent: + # We just upgraded a delta to a full part, so emit a PartStartEvent return PartStartEvent(index=part_index, part=updated_part) else: - # In this case, we just updated an existing part, so emit a PartDeltaEvent: + # We updated an existing part, so emit a PartDeltaEvent return PartDeltaEvent(index=part_index, delta=delta) def handle_tool_call_part( @@ -125,11 +216,28 @@ def handle_tool_call_part( args: str | dict[str, Any], tool_call_id: str | None = None, ) -> ModelResponseStreamEvent: + """Immediately create or fully-overwrite a ToolCallPart with the given information. + + This does not apply a delta; it directly sets the tool call part contents. + + Args: + vendor_part_id: The vendor's ID for this tool call part. If not + None and an existing part is found, that part is overwritten. + tool_name: The name of the tool being invoked. + args: The arguments for the tool call, either as a string or a dictionary. + tool_call_id: An optional string identifier for this tool call. + + Returns: + ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part + has been added to the manager, or replaced an existing part. 
+ """ new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id) if vendor_part_id is None: + # vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list new_part_index = len(self._parts) self._parts.append(new_part) else: + # vendor_part_id is provided, so find and overwrite or create a new ToolCallPart. maybe_part_index = self._vendor_id_to_part_index.get(vendor_part_id) if maybe_part_index is not None: new_part_index = maybe_part_index diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index f84c161e..200ab7e3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -191,9 +191,7 @@ async def __anext__(self) -> ModelResponseStreamEvent: async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) else: delta_tool_calls = item for dtc_index, delta_tool_call in delta_tool_calls.items(): diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 6159d737..284bf6b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -322,11 +322,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas - maybe_event = self._parts_manager.handle_text_delta( - vendor_part_id=None, content=gemini_part['text'] - ) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text']) elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. 
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 1a0f2109..56afd3af 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -312,9 +312,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) # Handle the tool calls for dtc in choice.delta.tool_calls or []: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index ea8e0ea0..08ed71ac 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -497,9 +497,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 76243fcd..1b5a7b46 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -303,9 +303,7 @@ async def _get_events_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 728ae7ac..fa7431c2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -241,13 +241,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if len(words) == 1 and len(text) > 2: mid = len(text) // 2 words = [text[:mid], text[mid:]] - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='') for word in words: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) - if maybe_event is not None: - yield maybe_event + yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) else: args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict yield self._parts_manager.handle_tool_call_part( From cfa72d099c47aaa8b2bb9a061893195761dec17e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 14:15:39 -0700 Subject: [PATCH 23/34] Get rid of some of the excessive is_last tracking --- pydantic_ai_slim/pydantic_ai/_utils.py | 14 +++--- 
pydantic_ai_slim/pydantic_ai/result.py | 61 ++++++++++++++------------ tests/test_streaming.py | 16 ++++--- tests/test_utils.py | 12 ++--- 4 files changed, 55 insertions(+), 48 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 53446afa..293d84df 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -137,7 +137,7 @@ def __repr__(self): @asynccontextmanager async def group_by_temporal( aiter: AsyncIterator[T], soft_max_interval: float | None -) -> AsyncIterator[AsyncIterable[tuple[list[T], bool]]]: +) -> AsyncIterator[AsyncIterable[list[T]]]: """Group items from an async iterable into lists based on time interval between them. Effectively debouncing the iterator. @@ -165,10 +165,9 @@ async def group_by_temporal( """ if soft_max_interval is None: - async def async_iter_groups_noop() -> AsyncIterator[tuple[list[T], bool]]: + async def async_iter_groups_noop() -> AsyncIterator[list[T]]: async for item in aiter: - yield [item], False - yield [], True + yield [item] yield async_iter_groups_noop() return @@ -176,7 +175,7 @@ async def async_iter_groups_noop() -> AsyncIterator[tuple[list[T], bool]]: # we might wait for the next item more than once, so we store the task to await next time task: asyncio.Task[T] | None = None - async def async_iter_groups() -> AsyncIterator[tuple[list[T], bool]]: + async def async_iter_groups() -> AsyncIterator[list[T]]: nonlocal task assert soft_max_interval is not None and soft_max_interval >= 0, 'soft_max_interval must be a positive number' @@ -206,7 +205,8 @@ async def async_iter_groups() -> AsyncIterator[tuple[list[T], bool]]: item = done.pop().result() except StopAsyncIteration: # if the task raised StopAsyncIteration, we're done iterating - yield buffer, True + if buffer: + yield buffer task = None break else: @@ -218,7 +218,7 @@ async def async_iter_groups() -> AsyncIterator[tuple[list[T], bool]]: group_start_time = time.monotonic() elif buffer: # otherwise if the task timeout expired and we have items in the buffer, yield the buffer - yield buffer, False + yield buffer # clear the buffer and reset the group start time ready for the next group buffer = [] group_start_time = None diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index fec69e64..9bc8f643 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import AsyncIterator, Awaitable, Callable from copy import deepcopy from dataclasses import dataclass, field @@ -223,41 +222,44 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = self._stream_response, self._usage_limits, self.usage ) - async def _stream_text_deltas() -> AsyncIterator[tuple[str, int]]: + async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: # if the response currently has any parts with content, yield those before streaming msg = self._stream_response.get() for i, part in enumerate(msg.parts): if isinstance(part, TextPart) and part.content: yield part.content, i - async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for events, _is_final in group_iter: - for event in events: - if ( - isinstance(event, _messages.PartStartEvent) - and isinstance(event.part, _messages.TextPart) - and event.part.content 
- ): - yield event.part.content, event.index - elif ( - isinstance(event, _messages.PartDeltaEvent) - and isinstance(event.delta, _messages.TextPartDelta) - and event.delta.content_delta - ): - yield event.delta.content_delta, event.index + async for event in usage_checking_stream: + if ( + isinstance(event, _messages.PartStartEvent) + and isinstance(event.part, _messages.TextPart) + and event.part.content + ): + yield event.part.content, event.index + elif ( + isinstance(event, _messages.PartDeltaEvent) + and isinstance(event.delta, _messages.TextPartDelta) + and event.delta.content_delta + ): + yield event.delta.content_delta, event.index + + async def _stream_text_deltas() -> AsyncIterator[str]: + async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: + async for items in group_iter: + yield ''.join([content for content, _ in items]) with _logfire.span('response stream text') as lf_span: if delta: - async for text, _ in _stream_text_deltas(): + async for text in _stream_text_deltas(): yield text else: # a quick benchmark shows it's faster to build up a string with concat when we're # yielding at each step - chunks: dict[int, str] = defaultdict(str) + deltas: list[str] = [] combined_validated_text = '' - async for text, index in _stream_text_deltas(): - chunks[index] += text - combined_text = ''.join([chunks[k] for k in sorted(chunks)]) + async for text in _stream_text_deltas(): + deltas.append(text) + combined_text = ''.join(deltas) combined_validated_text = await self._validate_text_result(combined_text) yield combined_validated_text @@ -290,13 +292,14 @@ async def stream_structured( break async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for _events, is_final in group_iter: - msg = self._stream_response.get(final=is_final) - yield msg, is_final - if is_final: - # TODO: Should this now be `final_response` instead of `structured_response`? - lf_span.set_attribute('structured_response', msg) - await self._marked_completed(msg) + async for _events in group_iter: + msg = self._stream_response.get() + yield msg, False + msg = self._stream_response.get(final=True) + yield msg, True + # TODO: Should this now be `final_response` instead of `structured_response`? + lf_span.set_attribute('structured_response', msg) + await self._marked_completed(msg) async def get_data(self) -> ResultData: """Stream the whole response, validate and return it.""" diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1c2d3b49..ed267b2f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -137,13 +137,20 @@ async def test_streamed_text_stream(): async with agent.run_stream('Hello') as result: # typehint to test (via static typing) that the stream type is correctly inferred - chunks: list[str] = [c async for c in result.stream()] - # one chunk due to group_by_temporal + chunks: list[str] = [c async for c in result.stream_text()] + # one chunk with `stream_text()` due to group_by_temporal assert chunks == snapshot(['The cat sat on the mat.']) assert result.is_complete async with agent.run_stream('Hello') as result: - assert [c async for c in result.stream(debounce_by=None)] == snapshot( + # typehint to test (via static typing) that the stream type is correctly inferred + chunks: list[str] = [c async for c in result.stream()] + # two chunks with `stream()` due to not-final vs. 
final + assert chunks == snapshot(['The cat sat on the mat.', 'The cat sat on the mat.']) + assert result.is_complete + + async with agent.run_stream('Hello') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( [ 'The ', 'The cat ', @@ -151,9 +158,6 @@ async def test_streamed_text_stream(): 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.', - # This last value is repeated due to the debounce_by=None combined with the need to emit a final empty - # chunk to signal the end of the stream (which is used to determine whether to allow partial JSON) - 'The cat sat on the mat.', ] ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 96c79203..c2f199e1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -15,11 +15,11 @@ @pytest.mark.parametrize( 'interval,expected', [ - (None, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), - (0, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), - (0.02, snapshot([([1], False), ([2], False), ([3], False), ([], True)])), - (0.04, snapshot([([1, 2], False), ([3], True)])), - (0.1, snapshot([([1, 2, 3], True)])), + (None, snapshot([([1]), ([2]), ([3])])), + (0, snapshot([([1]), ([2]), ([3])])), + (0.02, snapshot([([1]), ([2]), ([3])])), + (0.04, snapshot([([1, 2]), ([3])])), + (0.1, snapshot([([1, 2, 3])])), ], ) async def test_group_by_temporal(interval: float | None, expected: list[list[int]]): @@ -32,7 +32,7 @@ async def yield_groups() -> AsyncIterator[int]: await asyncio.sleep(0.02) async with group_by_temporal(yield_groups(), soft_max_interval=interval) as groups_iter: - groups: list[tuple[list[int], bool]] = [g async for g in groups_iter] + groups: list[list[int]] = [g async for g in groups_iter] assert groups == expected From cd0924716b1f5e821a251b238d5a26c3a97c9307 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 14:19:49 -0700 Subject: [PATCH 24/34] Update a comment --- pydantic_ai_slim/pydantic_ai/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 293d84df..7e48e47b 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -160,8 +160,7 @@ async def group_by_temporal( as soon as `aiter.__anext__()` returns. If `None`, no grouping/debouncing is performed Returns: - A context manager usable as an async iterable of pairs of lists of items from the input async iterable, - and a boolean indicating whether the item was final coming out of the iterator. + A context manager usable as an async iterable of lists of items produced by the input async iterable. 
""" if soft_max_interval is None: From 43e0b0f15efc762e865f54ca78b4b04e3fbccff0 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:44:50 -0700 Subject: [PATCH 25/34] Move much of the aiter implementation up to StreamedResponse --- pydantic_ai_slim/pydantic_ai/_utils.py | 9 +++---- .../pydantic_ai/models/__init__.py | 13 +++++++--- .../pydantic_ai/models/function.py | 6 ----- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 ----- pydantic_ai_slim/pydantic_ai/models/groq.py | 6 ----- .../pydantic_ai/models/mistral.py | 16 +------------ pydantic_ai_slim/pydantic_ai/models/openai.py | 12 +++------- pydantic_ai_slim/pydantic_ai/models/test.py | 24 ++++--------------- pydantic_ai_slim/pydantic_ai/result.py | 6 ++--- 9 files changed, 26 insertions(+), 72 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 7e48e47b..74b055e6 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -136,7 +136,7 @@ def __repr__(self): @asynccontextmanager async def group_by_temporal( - aiter: AsyncIterator[T], soft_max_interval: float | None + aiterable: AsyncIterable[T], soft_max_interval: float | None ) -> AsyncIterator[AsyncIterable[list[T]]]: """Group items from an async iterable into lists based on time interval between them. @@ -154,7 +154,7 @@ async def group_by_temporal( ``` Args: - aiter: The async iterable to group. + aiterable: The async iterable to group. soft_max_interval: Maximum interval over which to group items, this should avoid a trickle of items causing a group to never be yielded. It's a soft max in the sense that once we're over this time, we yield items as soon as `aiter.__anext__()` returns. 
If `None`, no grouping/debouncing is performed @@ -165,7 +165,7 @@ async def group_by_temporal( if soft_max_interval is None: async def async_iter_groups_noop() -> AsyncIterator[list[T]]: - async for item in aiter: + async for item in aiterable: yield [item] yield async_iter_groups_noop() @@ -181,6 +181,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]: buffer: list[T] = [] group_start_time = time.monotonic() + aiterator = aiterable.__aiter__() while True: if group_start_time is None: # group hasn't started, we just wait for the maximum interval @@ -193,7 +194,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]: if task is None: # aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects # so far, this doesn't seem to be a problem - task = asyncio.create_task(aiter.__anext__()) # pyright: ignore[reportArgumentType] + task = asyncio.create_task(aiterator.__anext__()) # pyright: ignore[reportArgumentType] # we use asyncio.wait to avoid cancelling the coroutine if it's not done done, _ = await asyncio.wait((task,), timeout=wait_time) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index d4205600..d223ad8c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field from datetime import datetime from functools import cache from typing import TYPE_CHECKING, Literal @@ -138,17 +139,23 @@ async def request_stream( yield # pragma: no cover +@dataclass class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" + _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) + def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: """Stream the response as an async iterable of (optional) `ModelResponseStreamEvent`s.""" - return self + if self._event_iterator is None: + self._event_iterator = self._get_event_iterator() + return self._event_iterator @abstractmethod - async def __anext__(self) -> ModelResponseStreamEvent: - """Process the next chunk of the response.""" + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: raise NotImplementedError() + # noinspection PyUnreachableCode + yield @abstractmethod def get(self, *, final: bool = False) -> ModelResponse: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 200ab7e3..0e9c931c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -181,12 +181,6 @@ class FunctionStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() - return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py 
b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 284bf6b8..5527e7b0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -307,12 +307,6 @@ class GeminiStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() - return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for gemini_response in self._get_gemini_responses(): diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 56afd3af..f0439888 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -293,12 +293,6 @@ class GroqStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() - return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 08ed71ac..9f20cf83 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -453,23 +453,9 @@ class MistralStreamedResponse(StreamedResponse): _result_tools: dict[str, ToolDefinition] _usage: Usage = field(default_factory=Usage, init=False) - - # _function_tools: dict[str, MistralToolCall] = field(default_factory=dict, init=False) - _delta_content: str = '' - # _delta_tool_calls: dict[MistralToolCallId, MistralToolCall] = field(default_factory=dict, init=False) - # _result_part_index: int | None = field(default=None, init=False) - # _content_part_index: int | None = field(default=None, init=False) - # _tool_call_id_to_part_index: dict[MistralToolCallId, int] = field(default_factory=dict, init=False) - # _parts: dict[int, ModelResponsePart] = field(default_factory=dict, init=False) + _delta_content: str = field(default='', init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() - - return await self._event_iterator.__anext__() async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: chunk: MistralCompletionEvent diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 1b5a7b46..caed4574 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -12,7 +12,7 @@ from .. 
import UnexpectedModelBehavior, _utils, result from .._parts_manager import ModelResponsePartsManager -from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, ModelRequest, @@ -284,14 +284,8 @@ class OpenAIStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_events_iterator() - return await self._event_iterator.__anext__() - - async def _get_events_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: self._usage += _map_usage(chunk) @@ -323,7 +317,7 @@ def usage(self) -> Usage: return self._usage def timestamp(self) -> datetime: - return self._timestamp or _now_utc() + return self._timestamp def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index fa7431c2..53471fc6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -19,11 +19,8 @@ ModelResponse, ModelResponsePart, ModelResponseStreamEvent, - PartDeltaEvent, - PartStartEvent, RetryPromptPart, TextPart, - TextPartDelta, ToolCallPart, ToolReturnPart, ) @@ -219,17 +216,6 @@ class TestStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) - - async def __anext__(self) -> ModelResponseStreamEvent: - if self._event_iterator is None: - self._event_iterator = self._get_event_iterator() - - next_event = await self._event_iterator.__anext__() - - update = _estimate_event_usage(next_event) - self._usage += update - return next_event async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): @@ -241,8 +227,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if len(words) == 1 and len(text) > 2: mid = len(text) // 2 words = [text[:mid], text[mid:]] + self._usage += _get_string_usage('') yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='') for word in words: + self._usage += _get_string_usage(word) yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) else: args = part.args.args_json if isinstance(part.args, ArgsJson) else part.args.args_dict @@ -417,10 +405,6 @@ def _char(self) -> str: return s -def _estimate_event_usage(event: ModelResponseStreamEvent) -> Usage: - response_tokens = 0 - if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart): - response_tokens = _estimate_string_usage(event.part.content) - elif isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta): - response_tokens = _estimate_string_usage(event.delta.content_delta) +def _get_string_usage(text: str) -> Usage: + 
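
The test-model changes above move usage bookkeeping inline: rather than wrapping `__anext__` and re-inspecting each emitted event (the deleted `_estimate_event_usage`), usage is incremented right where each chunk is produced. A dependency-free sketch of that shape, with invented names:

    import asyncio
    from collections.abc import AsyncIterator


    class SketchStream:
        """Toy stand-in: usage grows as chunks are yielded, no __anext__ wrapper."""

        def __init__(self, words: list[str]) -> None:
            self._words = words
            self.response_tokens = 0

        async def events(self) -> AsyncIterator[str]:
            for word in self._words:
                self.response_tokens += 1  # one "token" per word, purely illustrative
                yield word


    async def main() -> None:
        stream = SketchStream(['hello', 'world'])
        async for word in stream.events():
            print(word, '->', stream.response_tokens)


    asyncio.run(main())
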
response_tokens = _estimate_string_usage(text) return Usage(response_tokens=response_tokens, total_tokens=response_tokens) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 9bc8f643..f5ac0a54 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime @@ -369,10 +369,10 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: def _get_usage_checking_stream_response( - stream_response: AsyncIterator[ModelResponseStreamEvent], + stream_response: AsyncIterable[ModelResponseStreamEvent], limits: UsageLimits | None, get_usage: Callable[[], Usage], -) -> AsyncIterator[ModelResponseStreamEvent]: +) -> AsyncIterable[ModelResponseStreamEvent]: if limits is not None and limits.has_token_limits(): async def _usage_checking_iterator(): From 94e8482fe5b7458f735fc19b0a288f1d23bd74d6 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:51:44 -0700 Subject: [PATCH 26/34] Move _parts_manager to StreamedResponse --- pydantic_ai_slim/pydantic_ai/_parts_manager.py | 2 +- pydantic_ai_slim/pydantic_ai/models/__init__.py | 2 ++ pydantic_ai_slim/pydantic_ai/models/function.py | 3 --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 3 --- pydantic_ai_slim/pydantic_ai/models/groq.py | 3 --- pydantic_ai_slim/pydantic_ai/models/mistral.py | 3 --- pydantic_ai_slim/pydantic_ai/models/openai.py | 3 --- pydantic_ai_slim/pydantic_ai/models/test.py | 3 --- 8 files changed, 3 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 201c41af..c24dab70 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, field from typing import Any, Union -from pydantic_ai import UnexpectedModelBehavior +from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( ModelResponsePart, ModelResponseStreamEvent, diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index d223ad8c..0ac13f06 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -16,6 +16,7 @@ import httpx +from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent from ..settings import ModelSettings @@ -143,6 +144,7 @@ async def request_stream( class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" + _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 0e9c931c..68726f99 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ 
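
The `AsyncIterator` -> `AsyncIterable` loosening in `_get_usage_checking_stream_response` matters because, after this refactor, `StreamedResponse` only implements `__aiter__` (no `__anext__`), so it satisfies the weaker protocol but not the stronger one. A self-contained illustration of the distinction:

    import asyncio
    from collections.abc import AsyncIterable, AsyncIterator


    class OnlyIterable:
        """Has __aiter__ but no __anext__ — like the refactored StreamedResponse."""

        def __aiter__(self) -> AsyncIterator[int]:
            async def gen() -> AsyncIterator[int]:
                yield 1
                yield 2

            return gen()


    async def main() -> None:
        obj = OnlyIterable()
        assert isinstance(obj, AsyncIterable)
        assert not isinstance(obj, AsyncIterator)  # no __anext__
        async for x in obj:
            print(x)


    asyncio.run(main())
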
b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -12,7 +12,6 @@ from typing_extensions import TypeAlias, assert_never, overload from .. import _utils, result -from .._parts_manager import ModelResponsePartsManager from .._utils import PeekableAsyncStream from ..messages import ( ModelMessage, @@ -180,8 +179,6 @@ class FunctionStreamedResponse(StreamedResponse): _iter: AsyncIterator[str | DeltaToolCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 5527e7b0..ab17cd9b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -15,7 +15,6 @@ from typing_extensions import NotRequired, TypedDict, assert_never from .. import UnexpectedModelBehavior, _utils, exceptions, result -from .._parts_manager import ModelResponsePartsManager from ..messages import ( ModelMessage, ModelRequest, @@ -306,8 +305,6 @@ class GeminiStreamedResponse(StreamedResponse): _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) _usage: result.Usage = field(default_factory=result.Usage, init=False) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for gemini_response in self._get_gemini_responses(): candidate = gemini_response['candidates'][0] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index f0439888..db25eb82 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -11,7 +11,6 @@ from typing_extensions import assert_never from .. import UnexpectedModelBehavior, _utils, result -from .._parts_manager import ModelResponsePartsManager from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, @@ -292,8 +291,6 @@ class GroqStreamedResponse(StreamedResponse): _timestamp: datetime _usage: result.Usage = field(default_factory=result.Usage, init=False) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: self._usage += _map_usage(chunk) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 9f20cf83..9438a363 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -13,7 +13,6 @@ from typing_extensions import assert_never from .. 
import UnexpectedModelBehavior, _utils -from .._parts_manager import ModelResponsePartsManager from .._utils import now_utc as _now_utc from ..messages import ( ArgsJson, @@ -455,8 +454,6 @@ class MistralStreamedResponse(StreamedResponse): _usage: Usage = field(default_factory=Usage, init=False) _delta_content: str = field(default='', init=False) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: chunk: MistralCompletionEvent async for chunk in self._response: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index caed4574..029b7812 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -11,7 +11,6 @@ from typing_extensions import assert_never from .. import UnexpectedModelBehavior, _utils, result -from .._parts_manager import ModelResponsePartsManager from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, @@ -283,8 +282,6 @@ class OpenAIStreamedResponse(StreamedResponse): _usage: result.Usage = field(default_factory=result.Usage, init=False) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: self._usage += _map_usage(chunk) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 53471fc6..5aac3bc7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -11,7 +11,6 @@ import pydantic_core from .. import _utils -from .._parts_manager import ModelResponsePartsManager from ..messages import ( ArgsJson, ModelMessage, @@ -215,8 +214,6 @@ class TestStreamedResponse(StreamedResponse): _usage: Usage _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): if isinstance(part, TextPart): From 38c1e299ee663c85bd32ae3d32e8ca79fb0e01ff Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:54:56 -0700 Subject: [PATCH 27/34] Move get() up to StreamedResponse --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 10 +--------- pydantic_ai_slim/pydantic_ai/models/function.py | 3 --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 5 ----- pydantic_ai_slim/pydantic_ai/models/groq.py | 4 ---- pydantic_ai_slim/pydantic_ai/models/mistral.py | 4 ---- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 ---- pydantic_ai_slim/pydantic_ai/models/test.py | 4 ---- 7 files changed, 1 insertion(+), 33 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 0ac13f06..7d44a325 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -159,16 +159,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noinspection PyUnreachableCode yield - @abstractmethod def get(self, *, final: bool = False) -> ModelResponse: - """Get the `ModelResponse` at this point. 
- - The `ModelResponse` may or may not be complete, depending on whether the stream is finished. - - Args: - final: If True, this is the final call, after iteration is complete, the response should be fully validated. - """ - raise NotImplementedError() + return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp()) @abstractmethod def usage(self) -> Usage: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 68726f99..b1f973ab 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -195,9 +195,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def get(self, *, final: bool = False) -> ModelResponse: - return ModelResponse(self._parts_manager.get_parts(), timestamp=self._timestamp) - def usage(self) -> result.Usage: return _estimate_usage([self.get()]) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index ab17cd9b..ca9a881c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -362,11 +362,6 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: self._usage += _metadata_as_usage(r) yield r - def get(self, *, final: bool = False) -> ModelResponse: - """Get the `ModelResponse` at this point.""" - parts = self._parts_manager.get_parts() - return ModelResponse(parts=parts, timestamp=self._timestamp) - def usage(self) -> result.Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index db25eb82..392c6da1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -316,10 +316,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def get(self, *, final: bool = False) -> ModelResponse: - parts = self._parts_manager.get_parts() - return ModelResponse(parts=parts, timestamp=self._timestamp) - def usage(self) -> Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 9438a363..c1c1bdd9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -489,10 +489,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id ) - def get(self, *, final: bool = False) -> ModelResponse: - parts = self._parts_manager.get_parts() - return ModelResponse(parts=parts, timestamp=self._timestamp) - def usage(self) -> Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 029b7812..8e6e361c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -306,10 +306,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def get(self, *, final: bool = False) -> ModelResponse: - parts = self._parts_manager.get_parts() - return ModelResponse(parts=parts, timestamp=self._timestamp) - def usage(self) -> Usage: return self._usage diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py 
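
With `get()` hoisted to the base class, any subclass that feeds the shared parts manager gets point-in-time `ModelResponse` snapshots for free. A toy sketch of that contract — `ToyPartsManager` and `ToyStreamedResponse` are illustrative stand-ins, not the real types:

    from dataclasses import dataclass, field


    @dataclass
    class ToyPartsManager:
        parts: list[str] = field(default_factory=list)

        def get_parts(self) -> list[str]:
            return list(self.parts)  # a copy, so callers can keep the snapshot


    @dataclass
    class ToyStreamedResponse:
        parts_manager: ToyPartsManager = field(default_factory=ToyPartsManager)

        def get(self) -> list[str]:
            # mirrors ModelResponse(parts=..., timestamp=...) in the real code
            return self.parts_manager.get_parts()


    response = ToyStreamedResponse()
    response.parts_manager.parts.append('hello ')
    print(response.get())  # mid-stream snapshot: ['hello ']
    response.parts_manager.parts.append('world')
    print(response.get())  # a later call sees the newly accumulated part
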
b/pydantic_ai_slim/pydantic_ai/models/test.py index 5aac3bc7..0f057379 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -235,10 +235,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id ) - def get(self, *, final: bool = False) -> ModelResponse: - parts = self._parts_manager.get_parts() - return ModelResponse(parts, timestamp=self._timestamp) - def usage(self) -> Usage: return self._usage From 9b287104b06358e5a100ccc282f11124fb949fad Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:58:33 -0700 Subject: [PATCH 28/34] Remove the unused 'final' argument to get --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 7d44a325..34415c26 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -159,7 +159,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noinspection PyUnreachableCode yield - def get(self, *, final: bool = False) -> ModelResponse: + def get(self) -> ModelResponse: return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp()) @abstractmethod diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index f5ac0a54..79cdc681 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -295,7 +295,7 @@ async def stream_structured( async for _events in group_iter: msg = self._stream_response.get() yield msg, False - msg = self._stream_response.get(final=True) + msg = self._stream_response.get() yield msg, True # TODO: Should this now be `final_response` instead of `structured_response`? lf_span.set_attribute('structured_response', msg) @@ -309,7 +309,7 @@ async def get_data(self) -> ResultData: async for _ in usage_checking_stream: pass - message = self._stream_response.get(final=True) + message = self._stream_response.get() await self._marked_completed(message) return await self.validate_structured_result(message) From 277ba849351458d7b0ecdfe1a75f9461ed8be29c Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:22:18 -0700 Subject: [PATCH 29/34] Move usage() up to StreamedResponse --- docs/results.md | 2 +- .../pydantic_ai/models/__init__.py | 10 ++---- .../pydantic_ai/models/anthropic.py | 26 +++++++------- .../pydantic_ai/models/function.py | 33 ++++++++++-------- pydantic_ai_slim/pydantic_ai/models/gemini.py | 14 +++----- pydantic_ai_slim/pydantic_ai/models/groq.py | 29 +++++++--------- .../pydantic_ai/models/mistral.py | 4 --- pydantic_ai_slim/pydantic_ai/models/openai.py | 34 ++++++++----------- pydantic_ai_slim/pydantic_ai/models/test.py | 22 ++++++------ 9 files changed, 78 insertions(+), 96 deletions(-) diff --git a/docs/results.md b/docs/results.md index 802619a8..d1898cef 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,5 +1,5 @@ Results are the final values returned from [running an agent](agents.md#running-agents). 
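
Dropping the unused `final` flag leaves one uniform `get()` call in `result.py`, serving both the per-group partial snapshots and the post-iteration final one. A condensed, hypothetical sketch of that call pattern (`FakeStream` and the chunking are invented for illustration):

    import asyncio
    from collections.abc import AsyncIterator


    class FakeStream:
        def __init__(self) -> None:
            self.parts: list[str] = []

        def get(self) -> list[str]:
            return list(self.parts)


    async def stream_structured(stream: FakeStream) -> AsyncIterator[tuple[list[str], bool]]:
        for chunk in ('hello ', 'world'):
            stream.parts.append(chunk)
            yield stream.get(), False  # partial snapshot per debounced group
        yield stream.get(), True       # final snapshot, after iteration completes


    async def main() -> None:
        async for snapshot, is_last in stream_structured(FakeStream()):
            print(is_last, snapshot)


    asyncio.run(main())
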
-The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.result.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) +The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 34415c26..3c2b77ee 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -20,9 +20,9 @@ from ..exceptions import UserError from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent from ..settings import ModelSettings +from ..usage import Usage if TYPE_CHECKING: - from ..result import Usage from ..tools import ToolDefinition @@ -144,6 +144,7 @@ async def request_stream( class StreamedResponse(ABC): """Streamed response from an LLM when calling a tool.""" + _usage: Usage = field(default_factory=Usage, init=False) _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False) _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False) @@ -162,13 +163,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: def get(self) -> ModelResponse: return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp()) - @abstractmethod def usage(self) -> Usage: - """Get the usage of the request. - - NOTE: this won't return the full usage until the stream is finished. - """ - raise NotImplementedError() + return self._usage @abstractmethod def timestamp(self) -> datetime: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index d1629f20..d62f6ccb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -8,7 +8,7 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import result +from .. 
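
The mechanical `usage` -> `response_usage` renames in the mapping helpers are not cosmetic: once the module itself is imported under the name `usage`, a function-local variable called `usage` would shadow it, breaking the later `usage.Usage(...)` constructor calls. A minimal, unrelated stdlib reproduction of the pitfall:

    import math


    def broken(x: float) -> float:
        math = x             # local assignment shadows the imported module
        return math.sqrt(x)  # AttributeError if called: float has no 'sqrt'


    def fixed(x: float) -> float:
        value = x            # distinct local name, like `response_usage` above
        return math.sqrt(value)


    print(fixed(9.0))  # 3.0
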
import usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ArgsDict, @@ -158,7 +158,7 @@ class AnthropicAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Usage]: + ) -> tuple[ModelResponse, usage.Usage]: response = await self._messages_create(messages, False, model_settings) return self._process_response(response), _map_usage(response) @@ -315,30 +315,30 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam: ) -def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage: +def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage: if isinstance(message, AnthropicMessage): - usage = message.usage + response_usage = message.usage else: if isinstance(message, RawMessageStartEvent): - usage = message.message.usage + response_usage = message.message.usage elif isinstance(message, RawMessageDeltaEvent): - usage = message.usage + response_usage = message.usage else: # No usage information provided in: # - RawMessageStopEvent # - RawContentBlockStartEvent # - RawContentBlockDeltaEvent # - RawContentBlockStopEvent - usage = None + response_usage = None - if usage is None: - return result.Usage() + if response_usage is None: + return usage.Usage() - request_tokens = getattr(usage, 'input_tokens', None) + request_tokens = getattr(response_usage, 'input_tokens', None) - return result.Usage( + return usage.Usage( # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr request_tokens=request_tokens, - response_tokens=usage.output_tokens, - total_tokens=(request_tokens or 0) + usage.output_tokens, + response_tokens=response_usage.output_tokens, + total_tokens=(request_tokens or 0) + response_usage.output_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index b1f973ab..6d66f18c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -11,7 +11,7 @@ from typing_extensions import TypeAlias, assert_never, overload -from .. import _utils, result +from .. 
import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( ModelMessage, @@ -143,7 +143,7 @@ class FunctionAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Usage]: + ) -> tuple[ModelResponse, usage.Usage]: agent_info = replace(self.agent_info, model_settings=model_settings) assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' @@ -179,13 +179,21 @@ class FunctionStreamedResponse(StreamedResponse): _iter: AsyncIterator[str | DeltaToolCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) + def __post_init__(self): + self._usage += _estimate_usage([]) + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for item in self._iter: if isinstance(item, str): + response_tokens = _estimate_string_tokens(item) + self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) else: delta_tool_calls = item for dtc_index, delta_tool_call in delta_tool_calls.items(): + if delta_tool_call.json_args: + response_tokens = _estimate_string_tokens(delta_tool_call.json_args) + self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc_index, tool_name=delta_tool_call.name, @@ -195,14 +203,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def usage(self) -> result.Usage: - return _estimate_usage([self.get()]) - def timestamp(self) -> datetime: return self._timestamp -def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: +def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: """Very rough guesstimate of the token usage associated with a series of messages. This is designed to be used solely to give plausible numbers for testing! 
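
Alongside the `_estimate_string_usage` -> `_estimate_string_tokens` rename, the added `.strip()` fixes a real off-by-two: `re.split` yields empty leading/trailing fields for padded input, inflating the estimate. A standalone sketch of the heuristic:

    import re


    def estimate_string_tokens(content: str) -> int:
        if not content:
            return 0
        return len(re.split(r'[\s",.:]+', content.strip()))


    print(estimate_string_tokens('hello world'))      # 2
    print(estimate_string_tokens('  hello world  '))  # 2; without strip() it was 4
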
@@ -214,30 +219,30 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, (SystemPromptPart, UserPromptPart)): - request_tokens += _estimate_string_usage(part.content) + request_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolReturnPart): - request_tokens += _estimate_string_usage(part.model_response_str()) + request_tokens += _estimate_string_tokens(part.model_response_str()) elif isinstance(part, RetryPromptPart): - request_tokens += _estimate_string_usage(part.model_response()) + request_tokens += _estimate_string_tokens(part.model_response()) else: assert_never(part) elif isinstance(message, ModelResponse): for part in message.parts: if isinstance(part, TextPart): - response_tokens += _estimate_string_usage(part.content) + response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolCallPart): call = part - response_tokens += 1 + _estimate_string_usage(call.args_as_json_str()) + response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str()) else: assert_never(part) else: assert_never(message) - return result.Usage( + return usage.Usage( request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens ) -def _estimate_string_usage(content: str) -> int: +def _estimate_string_tokens(content: str) -> int: if not content: return 0 - return len(re.split(r'[\s",.:]+', content)) + return len(re.split(r'[\s",.:]+', content.strip())) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index ca9a881c..298923d8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -14,7 +14,7 @@ from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse from typing_extensions import NotRequired, TypedDict, assert_never -from .. import UnexpectedModelBehavior, _utils, exceptions, result +from .. 
import UnexpectedModelBehavior, _utils, exceptions, usage from ..messages import ( ModelMessage, ModelRequest, @@ -170,7 +170,7 @@ def __init__( async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Usage]: + ) -> tuple[ModelResponse, usage.Usage]: async with self._make_request(messages, False, model_settings) as http_response: response = _gemini_response_ta.validate_json(await http_response.aread()) return self._process_response(response), _metadata_as_usage(response) @@ -303,7 +303,6 @@ class GeminiStreamedResponse(StreamedResponse): _content: bytearray _stream: AsyncIterator[bytes] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _usage: result.Usage = field(default_factory=result.Usage, init=False) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for gemini_response in self._get_gemini_responses(): @@ -362,9 +361,6 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]: self._usage += _metadata_as_usage(r) yield r - def usage(self) -> result.Usage: - return self._usage - def timestamp(self) -> datetime: return self._timestamp @@ -588,14 +584,14 @@ class _GeminiUsageMetaData(TypedDict, total=False): cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] -def _metadata_as_usage(response: _GeminiResponse) -> result.Usage: +def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage: metadata = response.get('usage_metadata') if metadata is None: - return result.Usage() + return usage.Usage() details: dict[str, int] = {} if cached_content_token_count := metadata.get('cached_content_token_count'): details['cached_content_token_count'] = cached_content_token_count - return result.Usage( + return usage.Usage( request_tokens=metadata.get('prompt_token_count', 0), response_tokens=metadata.get('candidates_token_count', 0), total_tokens=metadata.get('total_token_count', 0), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 392c6da1..2f770737 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import UnexpectedModelBehavior, _utils, result +from .. import UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, @@ -25,7 +25,6 @@ ToolReturnPart, UserPromptPart, ) -from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( @@ -155,7 +154,7 @@ class GroqAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Usage]: + ) -> tuple[ModelResponse, usage.Usage]: response = await self._completions_create(messages, False, model_settings) return self._process_response(response), _map_usage(response) @@ -289,7 +288,6 @@ class GroqStreamedResponse(StreamedResponse): _response: AsyncIterable[ChatCompletionChunk] _timestamp: datetime - _usage: result.Usage = field(default_factory=result.Usage, init=False) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: @@ -316,9 +314,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def usage(self) -> Usage: - return self._usage - def timestamp(self) -> datetime: return self._timestamp @@ -331,18 +326,18 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: ) -def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage: - usage = None +def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage: + response_usage = None if isinstance(completion, ChatCompletion): - usage = completion.usage + response_usage = completion.usage elif completion.x_groq is not None: - usage = completion.x_groq.usage + response_usage = completion.x_groq.usage - if usage is None: - return result.Usage() + if response_usage is None: + return usage.Usage() - return result.Usage( - request_tokens=usage.prompt_tokens, - response_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens, + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index c1c1bdd9..9907c949 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -451,7 +451,6 @@ class MistralStreamedResponse(StreamedResponse): _timestamp: datetime _result_tools: dict[str, ToolDefinition] - _usage: Usage = field(default_factory=Usage, init=False) _delta_content: str = field(default='', init=False) async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: @@ -489,9 +488,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=index, tool_name=dtc.function.name, args=dtc.function.arguments, tool_call_id=dtc.id ) - def usage(self) -> Usage: - return self._usage - def timestamp(self) -> datetime: return self._timestamp diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 8e6e361c..83143aaf 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -10,7 +10,7 @@ from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import UnexpectedModelBehavior, _utils, result +from .. import UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ModelMessage, @@ -25,7 +25,6 @@ ToolReturnPart, UserPromptPart, ) -from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( @@ -144,7 +143,7 @@ class OpenAIAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Usage]: + ) -> tuple[ModelResponse, usage.Usage]: response = await self._completions_create(messages, False, model_settings) return self._process_response(response), _map_usage(response) @@ -280,8 +279,6 @@ class OpenAIStreamedResponse(StreamedResponse): _response: AsyncIterable[ChatCompletionChunk] _timestamp: datetime - _usage: result.Usage = field(default_factory=result.Usage, init=False) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: async for chunk in self._response: self._usage += _map_usage(chunk) @@ -306,9 +303,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if maybe_event is not None: yield maybe_event - def usage(self) -> Usage: - return self._usage - def timestamp(self) -> datetime: return self._timestamp @@ -321,19 +315,19 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: ) -def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage: - usage = response.usage - if usage is None: - return result.Usage() +def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage: + response_usage = response.usage + if response_usage is None: + return usage.Usage() else: details: dict[str, int] = {} - if usage.completion_tokens_details is not None: - details.update(usage.completion_tokens_details.model_dump(exclude_none=True)) - if usage.prompt_tokens_details is not None: - details.update(usage.prompt_tokens_details.model_dump(exclude_none=True)) - return result.Usage( - request_tokens=usage.prompt_tokens, - response_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens, + if response_usage.completion_tokens_details is not None: + details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True)) + if response_usage.prompt_tokens_details is not None: + details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True)) + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, details=details, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 0f057379..fe03de6a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -2,9 +2,9 @@ import re import string -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable from contextlib import asynccontextmanager -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from datetime import date, datetime, timedelta from typing import Any, Literal @@ -31,7 +31,7 @@ Model, StreamedResponse, ) -from .function import _estimate_string_usage, _estimate_usage # pyright: ignore[reportPrivateUsage] +from .function import _estimate_string_tokens, _estimate_usage # pyright: ignore[reportPrivateUsage] @dataclass @@ -141,9 +141,8 @@ async def request( async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None ) -> AsyncIterator[StreamedResponse]: - msg = self._request(messages, model_settings) - usage = _estimate_usage(messages) - yield TestStreamedResponse(msg, usage) + model_response = self._request(messages, model_settings) + yield 
TestStreamedResponse(model_response, messages) def gen_tool_args(self, tool_def: ToolDefinition) -> Any: return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate() @@ -211,9 +210,13 @@ class TestStreamedResponse(StreamedResponse): """A structured response that streams test data.""" _structured_response: ModelResponse - _usage: Usage + _messages: InitVar[Iterable[ModelMessage]] + _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) + def __post_init__(self, _messages: Iterable[ModelMessage]): + self._usage = _estimate_usage(_messages) + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for i, part in enumerate(self._structured_response.parts): if isinstance(part, TextPart): @@ -235,9 +238,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=i, tool_name=part.tool_name, args=args, tool_call_id=part.tool_call_id ) - def usage(self) -> Usage: - return self._usage - def timestamp(self) -> datetime: return self._timestamp @@ -399,5 +399,5 @@ def _char(self) -> str: def _get_string_usage(text: str) -> Usage: - response_tokens = _estimate_string_usage(text) + response_tokens = _estimate_string_tokens(text) return Usage(response_tokens=response_tokens, total_tokens=response_tokens) From 7f71db3680b2da48b6ebdac77691fb2d1ebcc442 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:46:10 -0700 Subject: [PATCH 30/34] Make MockAsyncStream generic --- tests/models/mock_async_stream.py | 26 ++++++++++++++++++++++ tests/models/test_groq.py | 29 ++++++------------------ tests/models/test_mistral.py | 37 ++++++++++++------------------- tests/models/test_openai.py | 29 ++++++------------------ 4 files changed, 54 insertions(+), 67 deletions(-) create mode 100644 tests/models/mock_async_stream.py diff --git a/tests/models/mock_async_stream.py b/tests/models/mock_async_stream.py new file mode 100644 index 00000000..03982dd8 --- /dev/null +++ b/tests/models/mock_async_stream.py @@ -0,0 +1,26 @@ +from __future__ import annotations as _annotations + +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +from pydantic_ai import _utils + +T = TypeVar('T') + + +@dataclass +class MockAsyncStream(Generic[T]): + _iter: Iterator[T] + + async def __anext__(self) -> T: + return _utils.sync_anext(self._iter) + + def __aiter__(self) -> MockAsyncStream[T]: + return self + + async def __aenter__(self): + return self + + async def __aexit__(self, *_args: Any): + pass diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 6f3522d6..46dea3a2 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import Iterator, Sequence +from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timezone from functools import cached_property @@ -11,7 +11,7 @@ from inline_snapshot import snapshot from typing_extensions import TypedDict -from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, _utils +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ( ArgsJson, ModelRequest, @@ -25,6 +25,7 @@ from pydantic_ai.result import Usage from ..conftest import IsNow, try_import +from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: from groq 
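
`TestStreamedResponse` now takes the request messages as an `InitVar`: they are consumed once in `__post_init__` to pre-compute usage and are never stored as a field on the dataclass. A minimal sketch of the pattern with invented names:

    from dataclasses import InitVar, dataclass, field


    @dataclass
    class Sketch:
        _messages: InitVar[list[str]]
        _usage: int = field(default=0, init=False)

        def __post_init__(self, _messages: list[str]) -> None:
            # consume the InitVar once; it is not retained as an attribute
            self._usage = sum(len(m.split()) for m in _messages)


    s = Sketch(['hello world', 'goodbye'])
    print(s._usage)                    # 3
    print(hasattr(s, '_messages'))     # False — InitVars are not stored
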
import AsyncGroq @@ -54,23 +55,6 @@ def test_init(): assert m.name() == 'groq:llama-3.1-70b-versatile' -@dataclass -class MockAsyncStream: - _iter: Iterator[chat.ChatCompletionChunk] - - async def __anext__(self) -> chat.ChatCompletionChunk: - return _utils.sync_anext(self._iter) - - def __aiter__(self) -> MockAsyncStream: - return self - - async def __aenter__(self): - return self - - async def __aexit__(self, *_args: Any): - pass - - @dataclass class MockGroq: completions: chat.ChatCompletion | list[chat.ChatCompletion] | None = None @@ -94,14 +78,15 @@ def create_mock_stream( async def chat_completions_create( self, *_args: Any, stream: bool = False, **_kwargs: Any - ) -> chat.ChatCompletion | MockAsyncStream: + ) -> chat.ChatCompletion | MockAsyncStream[chat.ChatCompletionChunk]: if stream: assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' # noinspection PyUnresolvedReferences if isinstance(self.stream[0], list): # pragma: no cover - response = MockAsyncStream(iter(self.stream[self.index])) # type: ignore + indexed_stream = cast(list[chat.ChatCompletionChunk], self.stream[self.index]) + response = MockAsyncStream(iter(indexed_stream)) else: - response = MockAsyncStream(iter(self.stream)) # type: ignore + response = MockAsyncStream(iter(cast(list[chat.ChatCompletionChunk], self.stream))) else: assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided' if isinstance(self.completions, list): diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index ff9d4fd3..3130ec34 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations import json -from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime, timezone from functools import cached_property @@ -12,7 +11,6 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from pydantic_ai import _utils from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ModelRetry from pydantic_ai.messages import ( @@ -28,6 +26,7 @@ ) from ..conftest import IsNow, try_import +from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: from mistralai import ( @@ -61,23 +60,6 @@ ] -@dataclass -class MockAsyncStream: - _iter: Iterator[MistralCompletionChunk] - - async def __anext__(self) -> MistralCompletionChunk: - return _utils.sync_anext(self._iter) - - def __aiter__(self): - return self - - async def __aenter__(self): - return self - - async def __aexit__(self, *_args: Any): - pass - - @dataclass class MockMistralAI: completions: MistralChatCompletionResponse | list[MistralChatCompletionResponse] | None = None @@ -107,15 +89,24 @@ def create_stream_mock( async def chat_completions_create( # pragma: no cover self, *_args: Any, stream: bool = False, **_kwargs: Any - ) -> MistralChatCompletionResponse | MockAsyncStream | list[MistralChatCompletionResponse]: - response: MistralChatCompletionResponse | MockAsyncStream | list[MistralChatCompletionResponse] + ) -> ( + MistralChatCompletionResponse + | MockAsyncStream[MistralChatCompletionResponse] + | list[MistralChatCompletionResponse] + ): + response: ( + MistralChatCompletionResponse + | MockAsyncStream[MistralChatCompletionResponse] + | list[MistralChatCompletionResponse] + ) if stream or self.stream: assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' if isinstance(self.stream[0], list): 
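
Hypothetical usage of the newly shared mock, assuming the repo's test packages are importable (the `tests.models` path comes from this patch) and that `_utils.sync_anext` translates exhaustion into `StopAsyncIteration`, as the existing streaming tests rely on; the type parameter is pinned per test module instead of keeping per-vendor copies:

    import asyncio

    from tests.models.mock_async_stream import MockAsyncStream


    async def main() -> None:
        stream: MockAsyncStream[int] = MockAsyncStream(iter([1, 2, 3]))
        async for chunk in stream:
            print(chunk)


    asyncio.run(main())
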
- response = MockAsyncStream(iter(self.stream[self.index])) # pyright: ignore[reportArgumentType] + indexed_stream = cast(list[MistralChatCompletionResponse], self.stream[self.index]) + response = MockAsyncStream(iter(indexed_stream)) else: - response = MockAsyncStream(iter(self.stream)) # pyright: ignore[reportArgumentType] + response = MockAsyncStream(iter(cast(list[MistralChatCompletionResponse], self.stream))) else: assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided' diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 995984b3..7c068662 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import json -from collections.abc import Iterator, Sequence +from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timezone from functools import cached_property @@ -11,7 +11,7 @@ from inline_snapshot import snapshot from typing_extensions import TypedDict -from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, _utils +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ( ModelRequest, ModelResponse, @@ -24,6 +24,7 @@ from pydantic_ai.result import Usage from ..conftest import IsNow, try_import +from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: from openai import AsyncOpenAI @@ -62,23 +63,6 @@ def test_init_with_base_url(): m.name() -@dataclass -class MockAsyncStream: - _iter: Iterator[chat.ChatCompletionChunk] - - async def __anext__(self) -> chat.ChatCompletionChunk: - return _utils.sync_anext(self._iter) - - def __aiter__(self) -> MockAsyncStream: - return self - - async def __aenter__(self): - return self - - async def __aexit__(self, *_args: Any): - pass - - @dataclass class MockOpenAI: completions: chat.ChatCompletion | list[chat.ChatCompletion] | None = None @@ -102,14 +86,15 @@ def create_mock_stream( async def chat_completions_create( # pragma: no cover self, *_args: Any, stream: bool = False, **_kwargs: Any - ) -> chat.ChatCompletion | MockAsyncStream: + ) -> chat.ChatCompletion | MockAsyncStream[chat.ChatCompletionChunk]: if stream: assert self.stream is not None, 'you can only used `stream=True` if `stream` is provided' # noinspection PyUnresolvedReferences if isinstance(self.stream[0], list): - response = MockAsyncStream(iter(self.stream[self.index])) # type: ignore + indexed_stream = cast(list[chat.ChatCompletionChunk], self.stream[self.index]) + response = MockAsyncStream(iter(indexed_stream)) else: - response = MockAsyncStream(iter(self.stream)) # type: ignore + response = MockAsyncStream(iter(cast(list[chat.ChatCompletionChunk], self.stream))) else: assert self.completions is not None, 'you can only used `stream=False` if `completions` are provided' if isinstance(self.completions, list): From 626a69d2b76e5a8e508bfa6cbef804bf525471d3 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:44:07 -0700 Subject: [PATCH 31/34] Add some tests of _parts_manager.py --- tests/test_parts_manager.py | 171 ++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 tests/test_parts_manager.py diff --git a/tests/test_parts_manager.py b/tests/test_parts_manager.py new file mode 100644 index 00000000..627c1c54 --- /dev/null +++ b/tests/test_parts_manager.py @@ -0,0 +1,171 @@ +from __future__ import 
annotations as _annotations + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai._parts_manager import ModelResponsePartsManager +from pydantic_ai.messages import ( + ArgsJson, + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ToolCallPart, + ToolCallPartDelta, +) + + +@pytest.mark.parametrize('vendor_part_id', [None, 'content']) +def test_handle_text_deltas(vendor_part_id: str | None): + manager = ModelResponsePartsManager() + assert manager.get_parts() == [] + + event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + + event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot([TextPart(content='hello world', part_kind='text')]) + + +def test_handle_dovetailed_text_deltas(): + manager = ModelResponsePartsManager() + + event = manager.handle_text_delta(vendor_part_id='first', content='hello ') + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + + event = manager.handle_text_delta(vendor_part_id='second', content='goodbye ') + assert event == snapshot( + PartStartEvent(index=1, part=TextPart(content='goodbye ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello ', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] + ) + + event = manager.handle_text_delta(vendor_part_id='first', content='world') + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye ', part_kind='text')] + ) + + event = manager.handle_text_delta(vendor_part_id='second', content='Samuel') + assert event == snapshot( + PartDeltaEvent( + index=1, delta=TextPartDelta(content_delta='Samuel', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot( + [TextPart(content='hello world', part_kind='text'), TextPart(content='goodbye Samuel', part_kind='text')] + ) + + +def test_handle_tool_call_deltas(): + manager = ModelResponsePartsManager() + + event = manager.handle_tool_call_delta( + vendor_part_id='first', tool_name='tool1', args='{"arg1":', tool_call_id=None + ) + assert event == snapshot( + PartStartEvent( + index=0, + part=ToolCallPart( + tool_name='tool1', args=ArgsJson(args_json='{"arg1":'), tool_call_id=None, part_kind='tool-call' + ), + event_kind='part_start', + ) + ) + assert manager.get_parts() == snapshot( + [ToolCallPart(tool_name='tool1', args=ArgsJson(args_json='{"arg1":'), tool_call_id=None, part_kind='tool-call')] + ) + + event = manager.handle_tool_call_delta(vendor_part_id='first', tool_name=None, args='"value1"}', tool_call_id=None) + assert event == snapshot( + PartDeltaEvent( + index=0, + delta=ToolCallPartDelta( + tool_name_delta=None, args_delta='"value1"}', tool_call_id=None, 
part_delta_kind='tool_call' + ), + event_kind='part_delta', + ) + ) + assert manager.get_parts() == snapshot( + [ + ToolCallPart( + tool_name='tool1', + args=ArgsJson(args_json='{"arg1":"value1"}'), + tool_call_id=None, + part_kind='tool-call', + ) + ] + ) + + +@pytest.mark.parametrize('vendor_part_id', [None, 'content']) +def test_handle_mixed_deltas_without_text_part_id(vendor_part_id: str | None): + manager = ModelResponsePartsManager() + + event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='hello ') + assert event == snapshot( + PartStartEvent(index=0, part=TextPart(content='hello ', part_kind='text'), event_kind='part_start') + ) + assert manager.get_parts() == snapshot([TextPart(content='hello ', part_kind='text')]) + + event = manager.handle_tool_call_delta( + vendor_part_id='first_tool_call', tool_name='tool1', args='{"arg1":', tool_call_id='abc' + ) + assert event == snapshot( + PartStartEvent( + index=1, + part=ToolCallPart( + tool_name='tool1', args=ArgsJson(args_json='{"arg1":'), tool_call_id='abc', part_kind='tool-call' + ), + event_kind='part_start', + ) + ) + + event = manager.handle_text_delta(vendor_part_id=vendor_part_id, content='world') + if vendor_part_id is None: + assert event == snapshot( + PartStartEvent( + index=2, + part=TextPart(content='world', part_kind='text'), + event_kind='part_start', + ) + ) + assert manager.get_parts() == snapshot( + [ + TextPart(content='hello ', part_kind='text'), + ToolCallPart( + tool_name='tool1', args=ArgsJson(args_json='{"arg1":'), tool_call_id='abc', part_kind='tool-call' + ), + TextPart(content='world', part_kind='text'), + ] + ) + else: + assert event == snapshot( + PartDeltaEvent( + index=0, delta=TextPartDelta(content_delta='world', part_delta_kind='text'), event_kind='part_delta' + ) + ) + assert manager.get_parts() == snapshot( + [ + TextPart(content='hello world', part_kind='text'), + ToolCallPart( + tool_name='tool1', args=ArgsJson(args_json='{"arg1":'), tool_call_id='abc', part_kind='tool-call' + ), + ] + ) From 08a2f4fedb90128db609a2ff7d5c5d623e736230 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Thu, 9 Jan 2025 12:32:11 -0800 Subject: [PATCH 32/34] Initial support for anthropic streaming --- .../pydantic_ai/models/anthropic.py | 99 ++++++++-- tests/models/test_anthropic.py | 169 +++++++++++++++++- 2 files changed, 244 insertions(+), 24 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index d62f6ccb..5921f38a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -1,14 +1,16 @@ from __future__ import annotations as _annotations -from collections.abc import AsyncIterator +from collections.abc import AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field -from typing import Any, Literal, Union, cast, overload +from datetime import datetime, timezone +from json import JSONDecodeError, loads as json_loads +from typing import Any, Dict, Literal, Union, cast, overload from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import usage +from .. 
import UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ArgsDict, @@ -16,6 +18,7 @@ ModelRequest, ModelResponse, ModelResponsePart, + ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, TextPart, @@ -38,11 +41,16 @@ from anthropic.types import ( Message as AnthropicMessage, MessageParam, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockStopEvent, RawMessageDeltaEvent, RawMessageStartEvent, + RawMessageStopEvent, RawMessageStreamEvent, TextBlock, TextBlockParam, + TextDelta, ToolChoiceParam, ToolParam, ToolResultBlockParam, @@ -231,22 +239,15 @@ def _process_response(response: AnthropicMessage) -> ModelResponse: @staticmethod async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse: - """TODO: Process a streamed response, and prepare a streaming response to return.""" - # We don't yet support streamed responses from Anthropic, so we raise an error here for now. - # Streamed responses will be supported in a future release. - - raise RuntimeError('Streamed responses are not yet supported for Anthropic models.') - - # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse - # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following: - # RawMessageStartEvent - # RawMessageDeltaEvent - # RawMessageStopEvent - # RawContentBlockStartEvent - # RawContentBlockDeltaEvent - # RawContentBlockDeltaEvent - # - # We might refactor streaming internally before we implement this... + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') + + # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time + timestamp = datetime.now(tz=timezone.utc) + return AnthropicStreamedResponse(peekable_response, timestamp) @staticmethod def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]: @@ -342,3 +343,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage response_tokens=response_usage.output_tokens, total_tokens=(request_tokens or 0) + response_usage.output_tokens, ) + + +@dataclass +class AnthropicStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Anthropic models.""" + + _response: AsyncIterable[RawMessageStreamEvent] + _timestamp: datetime + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + current_block: TextBlock | ToolUseBlock | None = None + current_json: str = '' + + async for event in self._response: + self._usage += _map_usage(event) + + if isinstance(event, RawContentBlockStartEvent): + current_block = event.content_block + if isinstance(current_block, TextBlock) and current_block.text: + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text) + elif isinstance(current_block, ToolUseBlock): + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=current_block.id, + tool_name=current_block.name, + args=cast(Dict[str, Any], current_block.input), + tool_call_id=current_block.id, + ) + if maybe_event is not None: + yield maybe_event + + elif isinstance(event, RawContentBlockDeltaEvent): + if 
isinstance(event.delta, TextDelta): + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text) + elif ( + current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock) + ): + # Try to parse the JSON immediately, otherwise cache the value for later. This handles + # cases where the JSON is not currently valid but will be valid once we stream more tokens. + try: + parsed_args = json_loads(current_json + event.delta.partial_json) + current_json = '' + except JSONDecodeError: + current_json += event.delta.partial_json + continue + + # For tool calls, we need to handle partial JSON updates + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=current_block.id, + tool_name='', + args=parsed_args, + tool_call_id=current_block.id, + ) + if maybe_event is not None: + yield maybe_event + + elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)): + current_block = None + + def timestamp(self) -> datetime: + return self._timestamp diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 66a16605..2a05e903 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1,10 +1,11 @@ from __future__ import annotations as _annotations import json +from collections.abc import AsyncIterator from dataclasses import dataclass from datetime import timezone from functools import cached_property -from typing import Any, cast +from typing import Any, TypeVar, cast import pytest from inline_snapshot import snapshot @@ -25,14 +26,24 @@ from ..conftest import IsNow, try_import with try_import() as imports_successful: - from anthropic import AsyncAnthropic + from anthropic import AsyncAnthropic, AsyncStream from anthropic.types import ( ContentBlock, + InputJSONDelta, Message as AnthropicMessage, + MessageDeltaUsage, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockStopEvent, + RawMessageDeltaEvent, + RawMessageStartEvent, + RawMessageStopEvent, + RawMessageStreamEvent, TextBlock, ToolUseBlock, Usage as AnthropicUsage, ) + from anthropic.types.raw_message_delta_event import Delta from pydantic_ai.models.anthropic import AnthropicModel @@ -41,6 +52,9 @@ pytest.mark.anyio, ] +# Type variable for generic AsyncStream +T = TypeVar('T') + def test_init(): m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar') @@ -48,9 +62,38 @@ def test_init(): assert m.name() == 'anthropic:claude-3-5-haiku-latest' +class MockAsyncStream(AsyncStream[T]): + """Mock implementation of AsyncStream for testing.""" + + def __init__(self, events: list[list[T]]): + self.events = events + self.stream_index = 0 + + def __aiter__(self) -> AsyncIterator[T]: + if self.stream_index >= len(self.events): + raise StopAsyncIteration + + async def iterator() -> AsyncIterator[T]: + current_stream = self.events[self.stream_index] + for event in current_stream: + yield event + self.stream_index += 1 + + return iterator() + + async def __anext__(self) -> T: + return await self._iterator.__anext__() + + async def __aenter__(self) -> MockAsyncStream[T]: + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass + + @dataclass class MockAnthropic: - messages_: AnthropicMessage | list[AnthropicMessage] | None = None + messages_: AnthropicMessage | list[AnthropicMessage] | AsyncStream[RawMessageStreamEvent] | None = None index = 0 @cached_property @@ -58,11 +101,18 @@ def messages(self) -> Any: return type('Messages', (), 
{'create': self.messages_create}) @classmethod - def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic: + def create_mock( + cls, messages_: AnthropicMessage | list[AnthropicMessage] | AsyncStream[RawMessageStreamEvent] + ) -> AsyncAnthropic: return cast(AsyncAnthropic, cls(messages_=messages_)) - async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage: + async def messages_create( + self, *_args: Any, stream: bool = False, **_kwargs: Any + ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]: assert self.messages_ is not None, '`messages` must be provided' + if isinstance(self.messages_, AsyncStream): + assert stream, 'stream must be True when using AsyncStream' + return self.messages_ if isinstance(self.messages_, list): response = self.messages_[self.index] else: @@ -241,3 +291,112 @@ async def get_location(loc_name: str) -> str: ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) + + +async def test_stream_structured(allow_model_requests: None): + """Test streaming structured responses with Anthropic's API. + + This test simulates how Anthropic streams tool calls: + 1. Message start + 2. Tool block start with initial data + 3. Tool block delta with additional data + 4. Tool block stop + 5. Update usage + 6. Message stop + """ + stream: list[RawMessageStreamEvent] = [ + RawMessageStartEvent( + type='message_start', + message=AnthropicMessage( + id='msg_123', + model='claude-3-5-haiku-latest', + role='assistant', + type='message', + content=[], + stop_reason=None, + usage=AnthropicUsage(input_tokens=20, output_tokens=0), + ), + ), + # Start tool block with initial data + RawContentBlockStartEvent( + type='content_block_start', + index=0, + content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}), + ), + # Add more data through an incomplete JSON delta + RawContentBlockDeltaEvent( + type='content_block_delta', + index=0, + delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'), + ), + RawContentBlockDeltaEvent( + type='content_block_delta', + index=0, + delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'), + ), + # Mark tool block as complete + RawContentBlockStopEvent(type='content_block_stop', index=0), + # Update the top-level message with usage + RawMessageDeltaEvent( + type='message_delta', + delta=Delta( + stop_reason='end_turn', + ), + usage=MessageDeltaUsage( + output_tokens=5, + ), + ), + # Mark message as complete + RawMessageStopEvent(type='message_stop'), + ] + + done_stream: list[RawMessageStreamEvent] = [ + RawMessageStartEvent( + type='message_start', + message=AnthropicMessage( + id='msg_123', + model='claude-3-5-haiku-latest', + role='assistant', + type='message', + content=[], + stop_reason=None, + usage=AnthropicUsage(input_tokens=0, output_tokens=0), + ), + ), + # Text block with final data + RawContentBlockStartEvent( + type='content_block_start', + index=0, + content_block=TextBlock(type='text', text='FINAL_PAYLOAD'), + ), + RawContentBlockStopEvent(type='content_block_stop', index=0), + RawMessageStopEvent(type='message_stop'), + ] + + mock_client = MockAnthropic.create_mock(MockAsyncStream([stream, done_stream])) + m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client) + agent = Agent(m) + + tool_called = False + + @agent.tool_plain + async def my_tool(first: str, second: str) -> int: + nonlocal tool_called + tool_called = True + return len(first) + len(second) + + 
async with agent.run_stream('') as result: + assert not result.is_complete + chunks = [c async for c in result.stream(debounce_by=None)] + + # The tool output doesn't echo any content to the stream, so we only get the final payload once when + # the block starts and once when it ends. + assert chunks == snapshot( + [ + 'FINAL_PAYLOAD', + 'FINAL_PAYLOAD', + ] + ) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25)) + assert tool_called From 57e0ee9d6195b010a24a91d44f936248ca9648c2 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Fri, 10 Jan 2025 09:50:12 -0800 Subject: [PATCH 33/34] MockAsyncStream requires global stream --- .../pydantic_ai/models/anthropic.py | 2 +- tests/models/test_anthropic.py | 42 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 2f0d5d2b..0fda7a33 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -3,7 +3,7 @@ from collections.abc import AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime from json import JSONDecodeError, loads as json_loads from typing import Any, Dict, Literal, Union, cast, overload diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 2a05e903..0bcc72c9 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -62,33 +62,35 @@ def test_init(): assert m.name() == 'anthropic:claude-3-5-haiku-latest' -class MockAsyncStream(AsyncStream[T]): - """Mock implementation of AsyncStream for testing.""" +if imports_successful(): - def __init__(self, events: list[list[T]]): - self.events = events - self.stream_index = 0 + class MockAsyncStream(AsyncStream[T]): + """Mock implementation of AsyncStream for testing.""" - def __aiter__(self) -> AsyncIterator[T]: - if self.stream_index >= len(self.events): - raise StopAsyncIteration + def __init__(self, events: list[list[T]]): + self.events = events + self.stream_index = 0 - async def iterator() -> AsyncIterator[T]: - current_stream = self.events[self.stream_index] - for event in current_stream: - yield event - self.stream_index += 1 + def __aiter__(self) -> AsyncIterator[T]: + if self.stream_index >= len(self.events): + raise StopAsyncIteration - return iterator() + async def iterator() -> AsyncIterator[T]: + current_stream = self.events[self.stream_index] + for event in current_stream: + yield event + self.stream_index += 1 - async def __anext__(self) -> T: - return await self._iterator.__anext__() + return iterator() - async def __aenter__(self) -> MockAsyncStream[T]: - return self + async def __anext__(self) -> T: + return await self._iterator.__anext__() - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - pass + async def __aenter__(self) -> MockAsyncStream[T]: + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass @dataclass From 256a3d95152f8329b0f31aa8b2b6c5bfd701a31f Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Fri, 10 Jan 2025 09:53:33 -0800 Subject: [PATCH 34/34] Restore anthropic stream processor --- .../pydantic_ai/models/anthropic.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git 
a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 0fda7a33..1c131382 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -3,14 +3,14 @@ from collections.abc import AsyncIterable, AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from json import JSONDecodeError, loads as json_loads -from typing import Any, Dict, Literal, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never -from .. import usage +from .. import UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( ArgsDict, @@ -239,22 +239,14 @@ def _process_response(response: AnthropicMessage) -> ModelResponse: @staticmethod async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse: - """TODO: Process a streamed response, and prepare a streaming response to return.""" - # We don't yet support streamed responses from Anthropic, so we raise an error here for now. - # Streamed responses will be supported in a future release. - - raise RuntimeError('Streamed responses are not yet supported for Anthropic models.') - - # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse - # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following: - # RawMessageStartEvent - # RawMessageDeltaEvent - # RawMessageStopEvent - # RawContentBlockStartEvent - # RawContentBlockDeltaEvent - # RawContentBlockDeltaEvent - # - # We might refactor streaming internally before we implement this... + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') + + # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time + timestamp = datetime.now(tz=timezone.utc) + return AnthropicStreamedResponse(peekable_response, timestamp) @staticmethod def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]: @@ -374,7 +366,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=current_block.id, tool_name=current_block.name, - args=cast(Dict[str, Any], current_block.input), + args=cast(dict[str, Any], current_block.input), tool_call_id=current_block.id, ) if maybe_event is not None:
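A note on `_utils.PeekableAsyncStream`, which the restored `_process_streamed_response` uses to detect an empty stream before constructing `AnthropicStreamedResponse`: its implementation is not part of this series. A minimal sketch of the peek-before-consume pattern, assuming an `Unset` sentinel like the one checked above (the real helper may differ in detail):

    from collections.abc import AsyncIterable, AsyncIterator
    from typing import Generic, TypeVar

    T = TypeVar('T')


    class Unset:
        """Sentinel indicating the stream produced no items."""


    UNSET = Unset()


    class PeekableAsyncStream(Generic[T]):
        """Wraps an async iterable so the first item can be inspected without consuming it."""

        def __init__(self, source: AsyncIterable[T]) -> None:
            self._it = source.__aiter__()
            self._buffer: list[T] = []  # holds at most one peeked item

        async def peek(self) -> T | Unset:
            # Pull one item ahead of time; iteration will yield it again later.
            if not self._buffer:
                try:
                    self._buffer.append(await self._it.__anext__())
                except StopAsyncIteration:
                    return UNSET
            return self._buffer[0]

        def __aiter__(self) -> AsyncIterator[T]:
            return self

        async def __anext__(self) -> T:
            if self._buffer:
                return self._buffer.pop()
            return await self._it.__anext__()

This keeps the first chunk available to the `async for` loop in `_get_event_iterator`, so peeking costs nothing beyond one buffered item.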
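The `input_json_delta` branch of `_get_event_iterator` is the subtle part of the stream processor: Anthropic may split a tool call's JSON arguments across chunks at arbitrary points, so fragments are accumulated in `current_json` until the buffer parses. A standalone sketch of that accumulate-then-parse loop, using the same fragments as `test_stream_structured` (the helper name is illustrative, not part of the patch):

    from json import JSONDecodeError, loads as json_loads
    from typing import Any


    def drain_json_deltas(fragments: list[str]) -> list[dict[str, Any]]:
        """Accumulate partial JSON fragments, emitting a dict each time the buffer parses."""
        buffer = ''
        parsed: list[dict[str, Any]] = []
        for fragment in fragments:
            buffer += fragment
            try:
                parsed.append(json_loads(buffer))
            except JSONDecodeError:
                continue  # not yet valid JSON; keep buffering
            buffer = ''  # reset after a successful parse
        return parsed


    # Mirrors the deltas in test_stream_structured: the first fragment alone is not valid JSON.
    assert drain_json_deltas(['{"second":', '"Two"}']) == [{'second': 'Two'}]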
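One wrinkle in the test helper: `MockAsyncStream.__anext__` delegates to `self._iterator`, which is never assigned, so calling it directly would raise `AttributeError`. The tests never hit this because `async for` goes through `__aiter__`, which returns a fresh generator. If standalone `__anext__` support were needed, a variant along these lines (a sketch, not part of the patches) would cache the generator:

    from collections.abc import AsyncIterator
    from typing import Generic, TypeVar

    T = TypeVar('T')


    class CachedMockAsyncStream(Generic[T]):
        """Hypothetical MockAsyncStream variant whose __anext__ works standalone."""

        def __init__(self, events: list[list[T]]) -> None:
            self.events = events
            self.stream_index = 0
            self._iterator: AsyncIterator[T] | None = None

        def __aiter__(self) -> AsyncIterator[T]:
            if self.stream_index >= len(self.events):
                raise StopAsyncIteration  # matches the behaviour of the helper above

            async def iterator() -> AsyncIterator[T]:
                for event in self.events[self.stream_index]:
                    yield event
                self.stream_index += 1

            self._iterator = iterator()  # cache the generator so __anext__ can delegate to it
            return self._iterator

        async def __anext__(self) -> T:
            if self._iterator is None:
                self._iterator = self.__aiter__()
            return await self._iterator.__anext__()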
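With the stream processor restored, streamed runs against Anthropic go through the same `Agent.run_stream` API exercised by `test_stream_structured`. A minimal end-to-end sketch, assuming `ANTHROPIC_API_KEY` is set in the environment and the `anthropic` extra is installed:

    import asyncio

    from pydantic_ai import Agent
    from pydantic_ai.models.anthropic import AnthropicModel

    # AnthropicModel reads the API key from the environment when none is passed explicitly.
    agent = Agent(AnthropicModel('claude-3-5-haiku-latest'))


    async def main() -> None:
        async with agent.run_stream('Tell me a joke.') as result:
            # Text deltas arrive incrementally as content-block events are processed.
            async for chunk in result.stream(debounce_by=None):
                print(chunk)
            print(result.usage())


    asyncio.run(main())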