Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add support for jpegs and pngs #3304

Merged
merged 15 commits into from
Jan 14, 2025
55 changes: 35 additions & 20 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from pathlib import Path

Expand All @@ -20,54 +21,68 @@
"""


def test_image_publish(client: WeaveClient) -> None:
@pytest.fixture(
params=[
{"format": None}, # Default PIL Image
{"format": "JPEG"},
{"format": "PNG"},
]
)
def test_img(request) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
weave.publish(img)

ref = get_ref(img)
if (fmt := request.param["format"]) is not None:
buffer = io.BytesIO()
img.save(buffer, format=fmt)
img = Image.open(buffer)

return img


def test_image_publish(client: WeaveClient, test_img: Image.Image) -> None:
weave.publish(test_img)

ref = get_ref(test_img)

assert ref is not None
gotten_img = weave.ref(ref.uri()).get()
assert img.tobytes() == gotten_img.tobytes()
assert test_img.tobytes() == gotten_img.tobytes()


class ImageWrapper(weave.Object):
img: Image.Image


def test_image_as_property(client: WeaveClient) -> None:
def test_image_as_property(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_property"
img = Image.new("RGB", (512, 512), "purple")
img_wrapper = ImageWrapper(img=img)
assert img_wrapper.img == img
img_wrapper = ImageWrapper(img=test_img)
assert img_wrapper.img == test_img

weave.publish(img_wrapper)

ref = get_ref(img_wrapper)
assert ref is not None

gotten_img_wrapper = weave.ref(ref.uri()).get()
assert gotten_img_wrapper.img.tobytes() == img.tobytes()
assert gotten_img_wrapper.img.tobytes() == test_img.tobytes()


def test_image_as_dataset_cell(client: WeaveClient) -> None:
def test_image_as_dataset_cell(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_dataset_cell"
img = Image.new("RGB", (512, 512), "purple")
dataset = weave.Dataset(rows=[{"img": img}])
assert dataset.rows[0]["img"] == img
dataset = weave.Dataset(rows=[{"img": test_img}])
assert dataset.rows[0]["img"] == test_img

weave.publish(dataset)

ref = get_ref(dataset)
assert ref is not None

gotten_dataset = weave.ref(ref.uri()).get()
assert gotten_dataset.rows[0]["img"].tobytes() == img.tobytes()
assert gotten_dataset.rows[0]["img"].tobytes() == test_img.tobytes()


@weave.op
def image_as_solo_output(publish_first: bool) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
def image_as_solo_output(publish_first: bool, img: Image.Image) -> Image.Image:
if publish_first:
weave.publish(img)
return img
Expand All @@ -79,9 +94,9 @@ def image_as_input_and_output_part(in_img: Image.Image) -> dict:


@pytest.mark.skip("Flaky in CI with Op loading exception.")
def test_image_as_call_io(client: WeaveClient) -> None:
def test_image_as_call_io(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io"
non_published_img = image_as_solo_output(publish_first=False)
non_published_img = image_as_solo_output(publish_first=False, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand All @@ -95,9 +110,9 @@ def test_image_as_call_io(client: WeaveClient) -> None:
assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes


def test_image_as_call_io_refs(client: WeaveClient) -> None:
def test_image_as_call_io_refs(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io_refs"
non_published_img = image_as_solo_output(publish_first=True)
non_published_img = image_as_solo_output(publish_first=True, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ const customWeaveTypeRegistry: {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.JpegImagePlugin.JpegImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.PngImagePlugin.PngImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'wave.Wave_read': {
component: AudioPlayer,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import {CustomWeaveTypePayload} from '../customWeaveType.types';

type PILImageImageTypePayload = CustomWeaveTypePayload<
'PIL.Image.Image',
{'image.png': string}
{'image.jpg'?: string; 'image.png'?: string}
andrewtruong marked this conversation as resolved.
Show resolved Hide resolved
>;

const DEFAULT_IMAGE_FILE_EXT = 'png';

export const isPILImageImageType = (
data: CustomWeaveTypePayload
): data is PILImageImageTypePayload => {
Expand All @@ -21,10 +23,17 @@ export const PILImageImage: React.FC<{
data: PILImageImageTypePayload;
}> = props => {
const {useFileContent} = useWFHooks();

const imageFileName = Object.keys(props.data.files)[0] as
andrewtruong marked this conversation as resolved.
Show resolved Hide resolved
| 'image.jpg'
| 'image.png';
const imageFileExt =
imageFileName.split('.').pop()?.toLowerCase() || DEFAULT_IMAGE_FILE_EXT;

const imageBinary = useFileContent(
props.entity,
props.project,
props.data.files['image.png']
props.data.files[imageFileName]!
);

if (imageBinary.loading) {
Expand All @@ -34,7 +43,9 @@ export const PILImageImage: React.FC<{
}

const arrayBuffer = imageBinary.result as any as ArrayBuffer;
const blob = new Blob([arrayBuffer], {type: 'image/png'});
const blob = new Blob([arrayBuffer], {
type: `image/${imageFileExt}`,
});
const url = URL.createObjectURL(blob);

// TODO: It would be nice to have a more general image render - similar to the
Expand Down
57 changes: 55 additions & 2 deletions weave/trace/util.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

import warnings
from collections.abc import Iterable, Iterator
from collections.abc import Iterable, Iterator, MutableMapping
from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor
from contextvars import Context, copy_context
from functools import partial, wraps
from threading import Thread as _Thread
from typing import Any, Callable
from typing import Any, Callable, TypeVar

LOG_ONCE_MESSAGE_SUFFIX = " (subsequent messages of this type will be suppressed)"
logged_messages = []
Expand Down Expand Up @@ -168,6 +168,59 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return deco


KT = TypeVar("KT")
VT = TypeVar("VT")


class InvertableDict(MutableMapping[KT, VT]):
andrewtruong marked this conversation as resolved.
Show resolved Hide resolved
"""A bijective mapping that behaves like a dict.

Invert the dict using the `inv` property.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
self._forward = dict(*args, **kwargs)
self._backward: dict[VT, KT] = {}
for key, value in self._forward.items():
if value in self._backward:
raise ValueError(f"Duplicate value found: {value}")
self._backward[value] = key

def __getitem__(self, key: KT) -> VT:
return self._forward[key]

def __setitem__(self, key: KT, value: VT) -> None:
if key in self._forward:
del self._backward[self._forward[key]]
if value in self._backward:
raise ValueError(f"Duplicate value found: {value}")
self._forward[key] = value
self._backward[value] = key

def __delitem__(self, key: KT) -> None:
value = self._forward.pop(key)
del self._backward[value]

def __iter__(self) -> Iterator[KT]:
return iter(self._forward)

def __len__(self) -> int:
return len(self._forward)

def __repr__(self) -> str:
return repr(self._forward)

def __contains__(self, key: Any) -> bool:
return key in self._forward

@property
def inv(self) -> InvertableDict[VT, KT]:
res = InvertableDict[VT, KT]()
res._forward = self._backward
res._backward = self._forward
return res


# rename for cleaner export
ThreadPoolExecutor = ContextAwareThreadPoolExecutor
Thread = ContextAwareThread
Expand Down
35 changes: 29 additions & 6 deletions weave/type_handlers/Image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from weave.trace import object_preparers, serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.trace.util import InvertableDict

try:
from PIL import Image
Expand All @@ -17,6 +18,16 @@

logger = logging.getLogger(__name__)

DEFAULT_FORMAT = "PNG"

pil_format_to_ext = InvertableDict[str, str](
{
"JPEG": "jpg",
"PNG": "png",
}
)
ext_to_pil_format = pil_format_to_ext.inv


class PILImagePreparer:
def should_prepare(self, obj: Any) -> bool:
Expand All @@ -35,6 +46,12 @@ def prepare(self, obj: Image.Image) -> None:


def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = pil_format_to_ext.get(fmt)
if ext is None:
logger.warning(f"Unknown image format {fmt}, defaulting to {DEFAULT_FORMAT}")
ext = pil_format_to_ext[DEFAULT_FORMAT]

# Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png".
# There is an extensive internal discussion here:
# https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
Expand All @@ -49,15 +66,21 @@ def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
# using the same artifact. Moreover, since we package the deserialization logic with the
# object payload, we can always change the serialization logic later without breaking
# existing payloads.
with artifact.new_file("image.png", binary=True) as f:
obj.save(f, format="png") # type: ignore
fname = f"image.{ext}"
with artifact.new_file(fname, binary=True) as f:
obj.save(f, format=ext_to_pil_format[ext]) # type: ignore


def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
# Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
# on save.
path = artifact.path("image.png")
return Image.open(path)
for ext in pil_format_to_ext.values():
andrewtruong marked this conversation as resolved.
Show resolved Hide resolved
# Note: I am purposely ignoring the `name` here and hard-coding the filename.
# See comment on save.
try:
path = artifact.path(f"image.{ext}")
except FileNotFoundError:
continue
return Image.open(path)
raise FileNotFoundError(f"No image found in artifact {artifact}")


def register() -> None:
Expand Down
Loading