From d913ba35eb3a95e80154b9d35e4c0a9f4a8dfeb1 Mon Sep 17 00:00:00 2001
From: Robert Craigie <robert@craigie.dev>
Date: Mon, 13 Jan 2025 15:20:57 +0000
Subject: [PATCH] feat(beta): add streaming helpers for beta messages (#819)

---
 src/anthropic/lib/streaming/__init__.py       |  13 +
 src/anthropic/lib/streaming/_beta_messages.py | 385 ++++++++++++++++++
 src/anthropic/lib/streaming/_beta_types.py    |  65 +++
 .../resources/beta/messages/messages.py       | 121 ++++++
 4 files changed, 584 insertions(+)
 create mode 100644 src/anthropic/lib/streaming/_beta_messages.py
 create mode 100644 src/anthropic/lib/streaming/_beta_types.py

diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py
index 0ab41209..103fff58 100644
--- a/src/anthropic/lib/streaming/__init__.py
+++ b/src/anthropic/lib/streaming/__init__.py
@@ -11,3 +11,16 @@
     MessageStreamManager as MessageStreamManager,
     AsyncMessageStreamManager as AsyncMessageStreamManager,
 )
+from ._beta_types import (
+    BetaTextEvent as BetaTextEvent,
+    BetaInputJsonEvent as BetaInputJsonEvent,
+    BetaMessageStopEvent as BetaMessageStopEvent,
+    BetaMessageStreamEvent as BetaMessageStreamEvent,
+    BetaContentBlockStopEvent as BetaContentBlockStopEvent,
+)
+from ._beta_messages import (
+    BetaMessageStream as BetaMessageStream,
+    BetaAsyncMessageStream as BetaAsyncMessageStream,
+    BetaMessageStreamManager as BetaMessageStreamManager,
+    BetaAsyncMessageStreamManager as BetaAsyncMessageStreamManager,
+)
diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py
new file mode 100644
index 00000000..48e419e9
--- /dev/null
+++ b/src/anthropic/lib/streaming/_beta_messages.py
@@ -0,0 +1,385 @@
+from __future__ import annotations
+
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Callable, cast
+from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never
+
+import httpx
+
+from ..._utils import consume_sync_iterator, consume_async_iterator
+from ..._models import build, construct_type
+from ._beta_types import (
+    BetaTextEvent,
+    BetaInputJsonEvent,
+    BetaMessageStopEvent,
+    BetaMessageStreamEvent,
+    BetaContentBlockStopEvent,
+)
+from ..._streaming import Stream, AsyncStream
+from ...types.beta import BetaMessage, BetaContentBlock, BetaRawMessageStreamEvent
+
+
+class BetaMessageStream:
+    text_stream: Iterator[str]
+    """Iterator over just the text deltas in the stream.
+
+    ```py
+    for text in stream.text_stream:
+        print(text, end="", flush=True)
+    print()
+    ```
+    """
+
+    def __init__(self, raw_stream: Stream[BetaRawMessageStreamEvent]) -> None:
+        self._raw_stream = raw_stream
+        self.text_stream = self.__stream_text__()  # lazy generator: nothing is read until it is iterated
+        self._iterator = self.__stream__()  # single shared generator: every way of iterating this object drains it
+        self.__final_message_snapshot: BetaMessage | None = None  # stays None until __stream__ handles the first event
+
+    @property
+    def response(self) -> httpx.Response:
+        return self._raw_stream.response
+
+    def __next__(self) -> BetaMessageStreamEvent:
+        return self._iterator.__next__()
+
+    def __iter__(self) -> Iterator[BetaMessageStreamEvent]:
+        for item in self._iterator:  # resumes wherever the shared iterator currently is
+            yield item
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        self._raw_stream.close()
+
+    def get_final_message(self) -> BetaMessage:
+        """Waits until the stream has been read to completion and returns
+        the accumulated `BetaMessage` object.
+        """
+        self.until_done()
+        assert self.__final_message_snapshot is not None
+        return self.__final_message_snapshot
+
+    def get_final_text(self) -> str:
+        """Returns all `text` content blocks concatenated together.
+
+        > [!NOTE]
+        > Currently the API will only respond with a single content block.
+
+        Raises `RuntimeError` if no `text` content blocks were returned.
+        """
+        message = self.get_final_message()
+        text_blocks: list[str] = []
+        for block in message.content:
+            if block.type == "text":
+                text_blocks.append(block.text)
+
+        if not text_blocks:
+            raise RuntimeError("Expected to have received at least 1 text block")
+
+        return "".join(text_blocks)
+
+    def until_done(self) -> None:
+        """Blocks until the stream has been consumed"""
+        consume_sync_iterator(self)
+
+    # properties
+    @property
+    def current_message_snapshot(self) -> BetaMessage:
+        assert self.__final_message_snapshot is not None  # fails if read before any event has been processed
+        return self.__final_message_snapshot
+
+    def __stream__(self) -> Iterator[BetaMessageStreamEvent]:
+        for sse_event in self._raw_stream:
+            self.__final_message_snapshot = accumulate_event(  # fold the raw event into the snapshot first
+                event=sse_event,
+                current_snapshot=self.__final_message_snapshot,
+            )
+
+            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
+            for event in events_to_fire:  # one raw event can produce several helper events
+                yield event
+
+    def __stream_text__(self) -> Iterator[str]:
+        for chunk in self:  # goes through __iter__, so it consumes the same shared iterator
+            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
+                yield chunk.delta.text
+
+
+class BetaMessageStreamManager:
+    """Wrapper over BetaMessageStream that is returned by `.stream()`.
+
+    ```py
+    with client.beta.messages.stream(...) as stream:
+        for chunk in stream:
+            ...
+    ```
+    """
+
+    def __init__(
+        self,
+        api_request: Callable[[], Stream[BetaRawMessageStreamEvent]],
+    ) -> None:
+        self.__stream: BetaMessageStream | None = None  # populated by __enter__
+        self.__api_request = api_request  # stored, not called: the request is deferred until __enter__
+
+    def __enter__(self) -> BetaMessageStream:
+        raw_stream = self.__api_request()  # the API request is issued here
+        self.__stream = BetaMessageStream(raw_stream)
+        return self.__stream
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        if self.__stream is not None:  # nothing to close if the context was never entered
+            self.__stream.close()
+
+
+class BetaAsyncMessageStream:
+    text_stream: AsyncIterator[str]
+    """Async iterator over just the text deltas in the stream.
+
+    ```py
+    async for text in stream.text_stream:
+        print(text, end="", flush=True)
+    print()
+    ```
+    """
+
+    def __init__(self, raw_stream: AsyncStream[BetaRawMessageStreamEvent]) -> None:
+        self._raw_stream = raw_stream
+        self.text_stream = self.__stream_text__()  # lazy async generator: nothing is read until it is iterated
+        self._iterator = self.__stream__()  # single shared generator: every way of iterating this object drains it
+        self.__final_message_snapshot: BetaMessage | None = None  # stays None until __stream__ handles the first event
+
+    @property
+    def response(self) -> httpx.Response:
+        return self._raw_stream.response
+
+    async def __anext__(self) -> BetaMessageStreamEvent:
+        return await self._iterator.__anext__()
+
+    async def __aiter__(self) -> AsyncIterator[BetaMessageStreamEvent]:
+        async for item in self._iterator:  # resumes wherever the shared iterator currently is
+            yield item
+
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.close()
+
+    async def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        await self._raw_stream.close()
+
+    async def get_final_message(self) -> BetaMessage:
+        """Waits until the stream has been read to completion and returns
+        the accumulated `BetaMessage` object.
+        """
+        await self.until_done()
+        assert self.__final_message_snapshot is not None
+        return self.__final_message_snapshot
+
+    async def get_final_text(self) -> str:
+        """Returns all `text` content blocks concatenated together.
+
+        > [!NOTE]
+        > Currently the API will only respond with a single content block.
+
+        Raises `RuntimeError` if no `text` content blocks were returned.
+        """
+        message = await self.get_final_message()
+        text_blocks: list[str] = []
+        for block in message.content:
+            if block.type == "text":
+                text_blocks.append(block.text)
+
+        if not text_blocks:
+            raise RuntimeError("Expected to have received at least 1 text block")
+
+        return "".join(text_blocks)
+
+    async def until_done(self) -> None:
+        """Waits until the stream has been consumed"""
+        await consume_async_iterator(self)
+
+    # properties
+    @property
+    def current_message_snapshot(self) -> BetaMessage:
+        assert self.__final_message_snapshot is not None  # fails if read before any event has been processed
+        return self.__final_message_snapshot
+
+    async def __stream__(self) -> AsyncIterator[BetaMessageStreamEvent]:
+        async for sse_event in self._raw_stream:
+            self.__final_message_snapshot = accumulate_event(  # fold the raw event into the snapshot first
+                event=sse_event,
+                current_snapshot=self.__final_message_snapshot,
+            )
+
+            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
+            for event in events_to_fire:  # one raw event can produce several helper events
+                yield event
+
+    async def __stream_text__(self) -> AsyncIterator[str]:
+        async for chunk in self:  # goes through __aiter__, so it consumes the same shared iterator
+            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
+                yield chunk.delta.text
+
+
+class BetaAsyncMessageStreamManager:
+    """Wrapper over BetaAsyncMessageStream that is returned by `.stream()`
+    so that an async context manager can be used without `await`ing the
+    original client call.
+
+    ```py
+    async with client.beta.messages.stream(...) as stream:
+        async for chunk in stream:
+            ...
+    ```
+    """
+
+    def __init__(
+        self,
+        api_request: Awaitable[AsyncStream[BetaRawMessageStreamEvent]],
+    ) -> None:
+        self.__stream: BetaAsyncMessageStream | None = None  # populated by __aenter__
+        self.__api_request = api_request  # stored un-awaited; awaited in __aenter__
+
+    async def __aenter__(self) -> BetaAsyncMessageStream:
+        raw_stream = await self.__api_request  # the only place the request awaitable is awaited
+        self.__stream = BetaAsyncMessageStream(raw_stream)
+        return self.__stream
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        if self.__stream is not None:  # nothing to close if the context was never entered
+            await self.__stream.close()
+
+
+def build_events(  # maps one raw SSE event to the list of helper events to yield for it
+    *,
+    event: BetaRawMessageStreamEvent,
+    message_snapshot: BetaMessage,  # snapshot that already has `event` accumulated into it (see the stream classes)
+) -> list[BetaMessageStreamEvent]:
+    events_to_fire: list[BetaMessageStreamEvent] = []
+
+    if event.type == "message_start":
+        events_to_fire.append(event)  # forwarded unchanged
+    elif event.type == "message_delta":
+        events_to_fire.append(event)  # forwarded unchanged
+    elif event.type == "message_stop":
+        events_to_fire.append(build(BetaMessageStopEvent, type="message_stop", message=message_snapshot))
+    elif event.type == "content_block_start":
+        events_to_fire.append(event)  # forwarded unchanged
+    elif event.type == "content_block_delta":
+        events_to_fire.append(event)  # raw delta always goes first; a derived event may follow below
+
+        content_block = message_snapshot.content[event.index]  # accumulated block this delta belongs to
+        if event.delta.type == "text_delta" and content_block.type == "text":
+            events_to_fire.append(
+                build(
+                    BetaTextEvent,
+                    type="text",
+                    text=event.delta.text,  # just this delta
+                    snapshot=content_block.text,  # all text accumulated for the block so far
+                )
+            )
+        elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
+            events_to_fire.append(
+                build(
+                    BetaInputJsonEvent,
+                    type="input_json",
+                    partial_json=event.delta.partial_json,  # just this delta
+                    snapshot=content_block.input,  # parsed tool input accumulated so far
+                )
+            )
+    elif event.type == "content_block_stop":
+        content_block = message_snapshot.content[event.index]
+
+        events_to_fire.append(  # replaces the raw stop event with one that carries the accumulated block
+            build(BetaContentBlockStopEvent, type="content_block_stop", index=event.index, content_block=content_block),
+        )
+    else:
+        # we only want exhaustive checking for linters, not at runtime
+        if TYPE_CHECKING:  # type: ignore[unreachable]
+            assert_never(event)
+
+    return events_to_fire  # empty for event types not handled above
+
+
+JSON_BUF_PROPERTY = "__json_buf"  # attribute name under which accumulate_event keeps a tool_use block's raw JSON bytes
+
+
+def accumulate_event(  # folds one raw SSE event into the running message snapshot and returns the snapshot
+    *,
+    event: BetaRawMessageStreamEvent,
+    current_snapshot: BetaMessage | None,  # None means no event has been accumulated yet
+) -> BetaMessage:
+    if current_snapshot is None:
+        if event.type == "message_start":
+            return BetaMessage.construct(**cast(Any, event.message.to_dict()))  # seed the snapshot from the payload
+
+        raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"')
+
+    if event.type == "content_block_start":
+        # TODO: check index
+        current_snapshot.content.append(  # appended in arrival order; event.index is not consulted
+            cast(
+                BetaContentBlock,
+                construct_type(type_=BetaContentBlock, value=event.content_block.model_dump()),
+            ),
+        )
+    elif event.type == "content_block_delta":
+        content = current_snapshot.content[event.index]  # the snapshot is mutated in place from here on
+        if content.type == "text" and event.delta.type == "text_delta":
+            content.text += event.delta.text
+        elif content.type == "tool_use" and event.delta.type == "input_json_delta":
+            from jiter import from_json
+
+            # we need to keep track of the raw JSON string as well so that we can
+            # re-parse it for each delta, for now we just store it as an untyped
+            # property on the snapshot
+            json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))  # b"" on the block's first delta
+            json_buf += bytes(event.delta.partial_json, "utf-8")
+
+            if json_buf:  # skip parsing while nothing has been received yet
+                content.input = from_json(json_buf, partial_mode=True)  # re-parse the whole buffer every time
+
+            setattr(content, JSON_BUF_PROPERTY, json_buf)
+    elif event.type == "message_delta":
+        current_snapshot.stop_reason = event.delta.stop_reason
+        current_snapshot.stop_sequence = event.delta.stop_sequence
+        current_snapshot.usage.output_tokens = event.usage.output_tokens  # overwritten with the event's value, not summed
+
+    return current_snapshot  # same object that was passed in, mutated
diff --git a/src/anthropic/lib/streaming/_beta_types.py b/src/anthropic/lib/streaming/_beta_types.py
new file mode 100644
index 00000000..a2a0bf6b
--- /dev/null
+++ b/src/anthropic/lib/streaming/_beta_types.py
@@ -0,0 +1,65 @@
+from typing import Union
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+from ...types.beta import (
+    BetaMessage,
+    BetaContentBlock,
+    BetaRawMessageStopEvent,
+    BetaRawMessageDeltaEvent,
+    BetaRawMessageStartEvent,
+    BetaRawContentBlockStopEvent,
+    BetaRawContentBlockDeltaEvent,
+    BetaRawContentBlockStartEvent,
+)
+
+
+class BetaTextEvent(BaseModel):  # helper event that build_events emits after a text delta
+    type: Literal["text"]
+
+    text: str
+    """The text delta"""
+
+    snapshot: str
+    """The entire accumulated text"""
+
+
+class BetaInputJsonEvent(BaseModel):  # helper event that build_events emits after a tool-input JSON delta
+    type: Literal["input_json"]
+
+    partial_json: str
+    """A partial JSON string delta
+
+    e.g. `'"San Francisco,'`
+    """
+
+    snapshot: object  # untyped: whatever partially parsing the JSON received so far produced
+    """The currently accumulated parsed object.
+
+
+    e.g. `{'location': 'San Francisco, CA'}`
+    """
+
+
+class BetaMessageStopEvent(BetaRawMessageStopEvent):  # the raw stop event extended with the accumulated message
+    type: Literal["message_stop"]
+
+    message: BetaMessage  # message snapshot as accumulated when the stop event arrived
+
+
+class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent):  # the raw stop event extended with the finished block
+    type: Literal["content_block_stop"]
+
+    content_block: BetaContentBlock  # accumulated snapshot block at the event's index
+
+
+BetaMessageStreamEvent = Union[  # every event type the beta stream helpers can yield
+    BetaTextEvent,  # helper-generated
+    BetaInputJsonEvent,  # helper-generated
+    BetaRawMessageStartEvent,
+    BetaRawMessageDeltaEvent,
+    BetaMessageStopEvent,  # stands in for the raw message_stop event
+    BetaRawContentBlockStartEvent,
+    BetaRawContentBlockDeltaEvent,
+    BetaContentBlockStopEvent,  # stands in for the raw content_block_stop event
+]
diff --git a/src/anthropic/resources/beta/messages/messages.py b/src/anthropic/resources/beta/messages/messages.py
index 62582d47..7aa89187 100644
--- a/src/anthropic/resources/beta/messages/messages.py
+++ b/src/anthropic/resources/beta/messages/messages.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from typing import List, Union, Iterable
+from functools import partial
 from itertools import chain
 from typing_extensions import Literal, overload
 
@@ -32,6 +33,7 @@
 from ...._streaming import Stream, AsyncStream
 from ....types.beta import message_create_params, message_count_tokens_params
 from ...._base_client import make_request_options
+from ....lib.streaming import BetaMessageStreamManager, BetaAsyncMessageStreamManager
 from ....types.model_param import ModelParam
 from ....types.beta.beta_message import BetaMessage
 from ....types.anthropic_beta_param import AnthropicBetaParam
@@ -922,6 +924,67 @@ def create(
             stream_cls=Stream[BetaRawMessageStreamEvent],
         )
 
+    def stream(
+        self,
+        *,
+        max_tokens: int,
+        messages: Iterable[BetaMessageParam],
+        model: ModelParam,
+        metadata: BetaMetadataParam | NotGiven = NOT_GIVEN,
+        stop_sequences: List[str] | NotGiven = NOT_GIVEN,
+        system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN,
+        temperature: float | NotGiven = NOT_GIVEN,
+        tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN,
+        tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN,
+        top_k: int | NotGiven = NOT_GIVEN,
+        top_p: float | NotGiven = NOT_GIVEN,
+        betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BetaMessageStreamManager:
+        """Create a Message stream; the request is only sent once the returned manager is entered."""
+        if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT:
+            timeout = 600  # applied only when neither caller nor client overrode the default (presumably seconds — confirm)
+
+        extra_headers = {
+            "X-Stainless-Stream-Helper": "beta.messages",
+            **strip_not_given({"anthropic-beta": ",".join(str(e) for e in betas) if is_given(betas) else NOT_GIVEN}),
+            **(extra_headers or {}),  # caller-supplied headers come last, so they override the two above
+        }
+        make_request = partial(  # deferred call: invoked by BetaMessageStreamManager.__enter__
+            self._post,
+            "/v1/messages?beta=true",
+            body=maybe_transform(
+                {
+                    "max_tokens": max_tokens,
+                    "messages": messages,
+                    "model": model,
+                    "metadata": metadata,
+                    "stop_sequences": stop_sequences,
+                    "system": system,
+                    "temperature": temperature,
+                    "top_k": top_k,
+                    "top_p": top_p,
+                    "tools": tools,
+                    "tool_choice": tool_choice,
+                    "stream": True,
+                },
+                message_create_params.MessageCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BetaMessage,
+            stream=True,
+            stream_cls=Stream[BetaRawMessageStreamEvent],
+        )
+        return BetaMessageStreamManager(make_request)
+    
+
     def count_tokens(
         self,
         *,
@@ -2030,6 +2093,64 @@ async def create(
             stream_cls=AsyncStream[BetaRawMessageStreamEvent],
         )
 
+    def stream(  # async counterpart of the sync beta `stream` helper; returns a manager for `async with`
+        self,
+        *,
+        max_tokens: int,
+        messages: Iterable[BetaMessageParam],
+        model: ModelParam,
+        metadata: BetaMetadataParam | NotGiven = NOT_GIVEN,
+        stop_sequences: List[str] | NotGiven = NOT_GIVEN,
+        system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN,
+        temperature: float | NotGiven = NOT_GIVEN,
+        tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN,
+        tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN,
+        top_k: int | NotGiven = NOT_GIVEN,
+        top_p: float | NotGiven = NOT_GIVEN,
+        betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN,
+        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+        # The extra values given here take precedence over values defined on the client or passed to this method.
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> BetaAsyncMessageStreamManager:
+        if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT:
+            timeout = 600  # applied only when neither caller nor client overrode the default (presumably seconds — confirm)
+
+        extra_headers = {
+            "X-Stainless-Stream-Helper": "beta.messages",
+            **strip_not_given({"anthropic-beta": ",".join(str(e) for e in betas) if is_given(betas) else NOT_GIVEN}),
+            **(extra_headers or {}),  # caller-supplied headers come last, so they override the two above
+        }
+        request = self._post(  # not awaited here; BetaAsyncMessageStreamManager.__aenter__ awaits it
+            "/v1/messages?beta=true",  # must target the beta endpoint, matching the sync `stream` helper
+            body=maybe_transform(
+                {
+                    "max_tokens": max_tokens,
+                    "messages": messages,
+                    "model": model,
+                    "metadata": metadata,
+                    "stop_sequences": stop_sequences,
+                    "system": system,
+                    "temperature": temperature,
+                    "top_k": top_k,
+                    "top_p": top_p,
+                    "tools": tools,
+                    "tool_choice": tool_choice,
+                    "stream": True,
+                },
+                message_create_params.MessageCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=BetaMessage,
+            stream=True,
+            stream_cls=AsyncStream[BetaRawMessageStreamEvent],
+        )
+        return BetaAsyncMessageStreamManager(request)
+
     async def count_tokens(
         self,
         *,