Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Dec 21, 2024
1 parent 1b77d8f commit e24aef5
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 7 deletions.
48 changes: 47 additions & 1 deletion weave/trace/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, MutableMapping
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
from contextvars import Context, copy_context
from functools import partial, wraps
Expand Down Expand Up @@ -168,6 +168,52 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return deco


class InvertableDict(MutableMapping):
    """A bijective mapping that behaves like a dict.

    Each key maps to exactly one value and each value maps back to exactly
    one key, so values must be hashable and unique. Invert the dict using
    the `inv` property.

    Raises:
        ValueError: On construction or assignment, if a value is already
            mapped from a different key.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the mapping from the same arguments `dict` accepts."""
        self._forward: dict[Any, Any] = dict(*args, **kwargs)
        self._backward: dict[Any, Any] = {}
        for key, value in self._forward.items():
            if value in self._backward:
                raise ValueError(f"Duplicate value found: {value}")
            self._backward[value] = key

    def __getitem__(self, key: Any) -> Any:
        return self._forward[key]

    def __setitem__(self, key: Any, value: Any) -> None:
        """Map `key` to `value`, replacing any previous value for `key`.

        Raises:
            ValueError: If `value` is already mapped from a different key.
                The mapping is left unchanged in that case.
        """
        # Validate before touching either dict: rejecting the assignment
        # after the old reverse entry has been removed would leave the
        # forward and backward mappings out of sync.
        if value in self._backward and self._backward[value] != key:
            raise ValueError(f"Duplicate value found: {value}")
        if key in self._forward:
            # Drop the reverse entry for the value being overwritten.
            del self._backward[self._forward[key]]
        self._forward[key] = value
        self._backward[value] = key

    def __delitem__(self, key: Any) -> None:
        # `pop` raises KeyError for a missing key before anything is mutated.
        value = self._forward.pop(key)
        del self._backward[value]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._forward)

    def __len__(self) -> int:
        return len(self._forward)

    def __repr__(self) -> str:
        return repr(self._forward)

    def __contains__(self, key: Any) -> bool:
        return key in self._forward

    @property
    def inv(self) -> dict[Any, Any]:
        """The value -> key mapping.

        This is the live internal dict, not a copy: it reflects later
        changes to this mapping. Treat it as read-only, since writing to it
        directly bypasses the bookkeeping that keeps both directions
        consistent.
        """
        return self._backward


# rename for cleaner export
ThreadPoolExecutor = ContextAwareThreadPoolExecutor
Thread = ContextAwareThread
Expand Down
34 changes: 28 additions & 6 deletions weave/type_handlers/Image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from weave.trace import object_preparers, serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.trace.util import InvertableDict

try:
from PIL import Image
Expand All @@ -17,6 +18,14 @@

logger = logging.getLogger(__name__)

DEFAULT_FORMAT = "PNG"
PIL_FORMAT_TO_EXT = InvertableDict(
{
"JPEG": "jpg",
"PNG": "png",
}
)


class PILImagePreparer:
def should_prepare(self, obj: Any) -> bool:
Expand All @@ -35,6 +44,12 @@ def prepare(self, obj: Image.Image) -> None:


def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = PIL_FORMAT_TO_EXT.get(fmt)
if ext is None:
logger.warning(f"Unknown image format {fmt}, defaulting to {DEFAULT_FORMAT}")
ext = PIL_FORMAT_TO_EXT[DEFAULT_FORMAT]

# Note: I am purposely ignoring the `name` here and hard-coding the filename stem to "image" (only the extension varies with the image format).
# There is an extensive internal discussion here:
# https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
Expand All @@ -49,15 +64,22 @@ def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
# using the same artifact. Moreover, since we package the deserialization logic with the
# object payload, we can always change the serialization logic later without breaking
# existing payloads.
with artifact.new_file("image.png", binary=True) as f:
obj.save(f, format="png") # type: ignore
fname = f"image.{ext}"
with artifact.new_file(fname, binary=True) as f:
obj.save(f, format=PIL_FORMAT_TO_EXT.inv[ext]) # type: ignore


def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
# Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
# on save.
path = artifact.path("image.png")
return Image.open(path)
for ext in PIL_FORMAT_TO_EXT.values():
# Note: I am purposely ignoring the `name` here and hard-coding the filename.
# See comment on save.
fname = f"image.{ext}"
path = artifact.path(fname)
try:
return Image.open(path)
except FileNotFoundError:
continue
raise FileNotFoundError(f"No image found in artifact {artifact}")


def register() -> None:
Expand Down

0 comments on commit e24aef5

Please sign in to comment.