Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(streaming): accumulate citations #844

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 64 additions & 31 deletions src/anthropic/lib/streaming/_beta_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx
from pydantic import BaseModel

from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._models import build, construct_type
from ._beta_types import (
BetaTextEvent,
BetaCitationEvent,
BetaInputJsonEvent,
BetaMessageStopEvent,
BetaMessageStreamEvent,
Expand Down Expand Up @@ -314,24 +316,40 @@ def build_events(
events_to_fire.append(event)

content_block = message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
events_to_fire.append(
build(
BetaTextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
if event.delta.type == "text_delta":
if content_block.type == "text":
events_to_fire.append(
build(
BetaTextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
events_to_fire.append(
build(
BetaInputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
elif event.delta.type == "input_json_delta":
if content_block.type == "tool_use":
events_to_fire.append(
build(
BetaInputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
)
elif event.delta.type == "citations_delta":
if content_block.type == "text":
events_to_fire.append(
build(
BetaCitationEvent,
type="citation",
citation=event.delta.citation,
snapshot=content_block.citations or [],
)
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "content_block_stop":
content_block = message_snapshot.content[event.index]

Expand All @@ -354,6 +372,9 @@ def accumulate_event(
event: BetaRawMessageStreamEvent,
current_snapshot: BetaMessage | None,
) -> BetaMessage:
if not isinstance(event, BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
raise TypeError(f"Unexpected event runtime type - {event}")

if current_snapshot is None:
if event.type == "message_start":
return BetaMessage.construct(**cast(Any, event.message.to_dict()))
Expand All @@ -370,21 +391,33 @@ def accumulate_event(
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
if event.delta.type == "text_delta":
if content.type == "text":
content.text += event.delta.text
elif event.delta.type == "input_json_delta":
if content.type == "tool_use":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
elif event.delta.type == "citations_delta":
if content.type == "text":
if not content.citations:
content.citations = [event.delta.citation]
else:
content.citations.append(event.delta.citation)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/lib/streaming/_beta_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import Literal, Annotated
from typing_extensions import List, Literal, Annotated

from ..._models import BaseModel
from ...types.beta import (
Expand All @@ -13,6 +13,7 @@
BetaRawContentBlockStartEvent,
)
from ..._utils._transform import PropertyInfo
from ...types.beta.beta_citations_delta import Citation


class BetaTextEvent(BaseModel):
Expand All @@ -25,6 +26,16 @@ class BetaTextEvent(BaseModel):
"""The entire accumulated text"""


# Stream event carrying one newly received citation together with the
# citations accumulated so far. In this diff, build_events() constructs it
# when a "citations_delta" arrives for a content block whose type is "text".
class BetaCitationEvent(BaseModel):
# Fixed event tag; always the literal "citation".
type: Literal["citation"]

# The citation taken from the delta that triggered this event.
citation: Citation
"""The new citation"""

# The content block's citations list at the time the event is built;
# build_events() passes an empty list when the block has no citations.
snapshot: List[Citation]
"""All of the accumulated citations"""


class BetaInputJsonEvent(BaseModel):
type: Literal["input_json"]

Expand Down Expand Up @@ -57,6 +68,7 @@ class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent):
BetaMessageStreamEvent = Annotated[
Union[
BetaTextEvent,
BetaCitationEvent,
BetaInputJsonEvent,
BetaRawMessageStartEvent,
BetaRawMessageDeltaEvent,
Expand Down
91 changes: 60 additions & 31 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ._types import (
TextEvent,
CitationEvent,
InputJsonEvent,
MessageStopEvent,
MessageStreamEvent,
Expand Down Expand Up @@ -315,24 +316,40 @@ def build_events(
events_to_fire.append(event)

content_block = message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
events_to_fire.append(
build(
TextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
if event.delta.type == "text_delta":
if content_block.type == "text":
events_to_fire.append(
build(
TextEvent,
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
)
elif event.delta.type == "input_json_delta" and content_block.type == "tool_use":
events_to_fire.append(
build(
InputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
elif event.delta.type == "input_json_delta":
if content_block.type == "tool_use":
events_to_fire.append(
build(
InputJsonEvent,
type="input_json",
partial_json=event.delta.partial_json,
snapshot=content_block.input,
)
)
)
elif event.delta.type == "citations_delta":
if content_block.type == "text":
events_to_fire.append(
build(
CitationEvent,
type="citation",
citation=event.delta.citation,
snapshot=content_block.citations or [],
)
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "content_block_stop":
content_block = message_snapshot.content[event.index]

Expand Down Expand Up @@ -374,21 +391,33 @@ def accumulate_event(
)
elif event.type == "content_block_delta":
content = current_snapshot.content[event.index]
if content.type == "text" and event.delta.type == "text_delta":
content.text += event.delta.text
elif content.type == "tool_use" and event.delta.type == "input_json_delta":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
if event.delta.type == "text_delta":
if content.type == "text":
content.text += event.delta.text
elif event.delta.type == "input_json_delta":
if content.type == "tool_use":
from jiter import from_json

# we need to keep track of the raw JSON string as well so that we can
# re-parse it for each delta, for now we just store it as an untyped
# property on the snapshot
json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b""))
json_buf += bytes(event.delta.partial_json, "utf-8")

if json_buf:
content.input = from_json(json_buf, partial_mode=True)

setattr(content, JSON_BUF_PROPERTY, json_buf)
elif event.delta.type == "citations_delta":
if content.type == "text":
if not content.citations:
content.citations = [event.delta.citation]
else:
content.citations.append(event.delta.citation)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event.delta)
elif event.type == "message_delta":
current_snapshot.stop_reason = event.delta.stop_reason
current_snapshot.stop_sequence = event.delta.stop_sequence
Expand Down
14 changes: 13 additions & 1 deletion src/anthropic/lib/streaming/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import Literal, Annotated
from typing_extensions import List, Literal, Annotated

from ...types import (
Message,
Expand All @@ -13,6 +13,7 @@
)
from ..._models import BaseModel
from ..._utils._transform import PropertyInfo
from ...types.citations_delta import Citation


class TextEvent(BaseModel):
Expand All @@ -25,6 +26,16 @@ class TextEvent(BaseModel):
"""The entire accumulated text"""


# Stream event carrying one newly received citation together with the
# citations accumulated so far. In this diff, build_events() constructs it
# when a "citations_delta" arrives for a content block whose type is "text".
class CitationEvent(BaseModel):
# Fixed event tag; always the literal "citation".
type: Literal["citation"]

# The citation taken from the delta that triggered this event.
citation: Citation
"""The new citation"""

# The content block's citations list at the time the event is built;
# build_events() passes an empty list when the block has no citations.
snapshot: List[Citation]
"""All of the accumulated citations"""


class InputJsonEvent(BaseModel):
type: Literal["input_json"]

Expand Down Expand Up @@ -57,6 +68,7 @@ class ContentBlockStopEvent(RawContentBlockStopEvent):
MessageStreamEvent = Annotated[
Union[
TextEvent,
CitationEvent,
InputJsonEvent,
RawMessageStartEvent,
RawMessageDeltaEvent,
Expand Down
Loading