Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(weave): Use sdk-local deserializer instead of saved deserializer for known types like Images #2696

Merged
merged 13 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/trace/test_custom_objs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from PIL import Image

from weave.trace.custom_objs import decode_custom_obj, encode_custom_obj


def test_decode_custom_obj_known_type(client):
img = Image.new("RGB", (100, 100))
encoded = encode_custom_obj(img)

# Even though something is wrong with the deserializer op, we can still decode
decoded = decode_custom_obj(
encoded["weave_type"], encoded["files"], "weave:///totally/invalid/uri"
)

assert isinstance(decoded, Image.Image)
assert decoded.tobytes() == img.tobytes()
40 changes: 24 additions & 16 deletions weave/trace/custom_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from weave.trace.refs import ObjectRef, parse_uri
from weave.trace.serializer import get_serializer_by_id, get_serializer_for_obj

# in future, could generalize as
# {target_cls.__module__}.{target_cls.__qualname__}
KNOWN_TYPES = ["PIL.Image.Image", "wave.Wave_read"]


def encode_custom_obj(obj: Any) -> Optional[dict]:
serializer = get_serializer_for_obj(obj)
Expand Down Expand Up @@ -52,31 +56,35 @@ def decode_custom_obj(
encoded_path_contents: Mapping[str, Union[str, bytes]],
load_instance_op_uri: Optional[str],
) -> Any:
load_instance_op = None
if load_instance_op_uri is not None:
_type = weave_type["type"]
found_serializer = False

# First, try to load the object using a known serializer
if _type in KNOWN_TYPES:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still not quite right - we need to try with the local one first, else fallback to the the remote one

serializer = get_serializer_by_id(_type)
if serializer is not None:
found_serializer = True
load_instance_op = serializer.load

# Otherwise, fall back to load_instance_op
if not found_serializer:
if load_instance_op_uri is None:
raise ValueError(f"No serializer found for `{_type}`")

ref = parse_uri(load_instance_op_uri)
if not isinstance(ref, ObjectRef):
raise ValueError(f"Expected ObjectRef, got {load_instance_op_uri}")
wc = require_weave_client()
load_instance_op = wc.get(ref)
raise TypeError(f"Expected ObjectRef, got `{type(ref)}`")

load_instance_op = ref.get()
if load_instance_op is None:
raise ValueError(
f"Failed to load op needed to decode object of type {weave_type}. See logs above for more information."
f"Failed to load op needed to decode object of type `{_type}`. See logs above for more information."
)

if load_instance_op is None:
serializer = get_serializer_by_id(weave_type["type"])
if serializer is None:
raise ValueError(f"No serializer found for {weave_type}")
load_instance_op = serializer.load

# Disables tracing so that calls to loading data itself don't get traced
load_instance_op._tracing_enabled = False # type: ignore

art = MemTraceFilesArtifact(
encoded_path_contents,
metadata={},
)
art = MemTraceFilesArtifact(encoded_path_contents, metadata={})
res = load_instance_op(art, "obj")
res.art = art
return res
Loading