Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add support for jpegs and pngs #3304

Merged
merged 15 commits into from
Jan 14, 2025
55 changes: 35 additions & 20 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from pathlib import Path

Expand All @@ -20,54 +21,68 @@
"""


def test_image_publish(client: WeaveClient) -> None:
@pytest.fixture(
params=[
{"format": None}, # Default PIL Image
{"format": "JPEG"},
{"format": "PNG"},
]
)
def test_img(request) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
weave.publish(img)

ref = get_ref(img)
if (fmt := request.param["format"]) is not None:
buffer = io.BytesIO()
img.save(buffer, format=fmt)
img = Image.open(buffer)

return img


def test_image_publish(client: WeaveClient, test_img: Image.Image) -> None:
weave.publish(test_img)

ref = get_ref(test_img)

assert ref is not None
gotten_img = weave.ref(ref.uri()).get()
assert img.tobytes() == gotten_img.tobytes()
assert test_img.tobytes() == gotten_img.tobytes()


class ImageWrapper(weave.Object):
img: Image.Image


def test_image_as_property(client: WeaveClient) -> None:
def test_image_as_property(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_property"
img = Image.new("RGB", (512, 512), "purple")
img_wrapper = ImageWrapper(img=img)
assert img_wrapper.img == img
img_wrapper = ImageWrapper(img=test_img)
assert img_wrapper.img == test_img

weave.publish(img_wrapper)

ref = get_ref(img_wrapper)
assert ref is not None

gotten_img_wrapper = weave.ref(ref.uri()).get()
assert gotten_img_wrapper.img.tobytes() == img.tobytes()
assert gotten_img_wrapper.img.tobytes() == test_img.tobytes()


def test_image_as_dataset_cell(client: WeaveClient) -> None:
def test_image_as_dataset_cell(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_dataset_cell"
img = Image.new("RGB", (512, 512), "purple")
dataset = weave.Dataset(rows=[{"img": img}])
assert dataset.rows[0]["img"] == img
dataset = weave.Dataset(rows=[{"img": test_img}])
assert dataset.rows[0]["img"] == test_img

weave.publish(dataset)

ref = get_ref(dataset)
assert ref is not None

gotten_dataset = weave.ref(ref.uri()).get()
assert gotten_dataset.rows[0]["img"].tobytes() == img.tobytes()
assert gotten_dataset.rows[0]["img"].tobytes() == test_img.tobytes()


@weave.op
def image_as_solo_output(publish_first: bool) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
def image_as_solo_output(publish_first: bool, img: Image.Image) -> Image.Image:
if publish_first:
weave.publish(img)
return img
Expand All @@ -79,9 +94,9 @@ def image_as_input_and_output_part(in_img: Image.Image) -> dict:


@pytest.mark.skip("Flaky in CI with Op loading exception.")
def test_image_as_call_io(client: WeaveClient) -> None:
def test_image_as_call_io(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io"
non_published_img = image_as_solo_output(publish_first=False)
non_published_img = image_as_solo_output(publish_first=False, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand All @@ -95,9 +110,9 @@ def test_image_as_call_io(client: WeaveClient) -> None:
assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes


def test_image_as_call_io_refs(client: WeaveClient) -> None:
def test_image_as_call_io_refs(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io_refs"
non_published_img = image_as_solo_output(publish_first=True)
non_published_img = image_as_solo_output(publish_first=True, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ const customWeaveTypeRegistry: {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.JpegImagePlugin.JpegImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.PngImagePlugin.PngImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'wave.Wave_read': {
component: AudioPlayer,
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import React from 'react';

import {LoadingDots} from '../../../../../LoadingDots';
import {NotApplicable} from '../../../Browse2/NotApplicable';
import {useWFHooks} from '../../pages/wfReactInterface/context';
import {CustomWeaveTypePayload} from '../customWeaveType.types';

type PILImageImageTypePayload = CustomWeaveTypePayload<
'PIL.Image.Image',
{'image.png': string}
{'image.jpg': string} | {'image.png': string}
>;

export const isPILImageImageType = (
Expand All @@ -21,20 +22,36 @@ export const PILImageImage: React.FC<{
data: PILImageImageTypePayload;
}> = props => {
const {useFileContent} = useWFHooks();

const imageTypes = {
'image.jpg': 'jpg',
'image.png': 'png',
} as const;

const imageKey = Object.keys(props.data.files).find(
key => key in imageTypes
) as keyof PILImageImageTypePayload['files'];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still not correct since you are not guaranteed to find the key. This cast is invalid and causes line 37 to be incorrect.

const imageBinary = useFileContent(
props.entity,
props.project,
props.data.files['image.png']
props.data.files[imageKey]
);

if (!imageKey) {
return <NotApplicable />;
}
const imageFileExt = imageTypes[imageKey as keyof typeof imageTypes];

if (imageBinary.loading) {
return <LoadingDots />;
} else if (imageBinary.result == null) {
return <span></span>;
}

const arrayBuffer = imageBinary.result as any as ArrayBuffer;
const blob = new Blob([arrayBuffer], {type: 'image/png'});
const blob = new Blob([arrayBuffer], {
type: `image/${imageFileExt}`,
});
const url = URL.createObjectURL(blob);

// TODO: It would be nice to have a more general image render - similar to the
Expand Down
31 changes: 26 additions & 5 deletions weave/type_handlers/Image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from weave.trace import object_preparers, serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.utils.invertable_dict import InvertableDict

try:
from PIL import Image
Expand All @@ -17,6 +18,16 @@

logger = logging.getLogger(__name__)

DEFAULT_FORMAT = "PNG"

pil_format_to_ext = InvertableDict[str, str](
{
"JPEG": "jpg",
"PNG": "png",
}
)
ext_to_pil_format = pil_format_to_ext.inv


class PILImagePreparer:
def should_prepare(self, obj: Any) -> bool:
Expand All @@ -35,6 +46,12 @@ def prepare(self, obj: Image.Image) -> None:


def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = pil_format_to_ext.get(fmt)
if ext is None:
logger.warning(f"Unknown image format {fmt}, defaulting to {DEFAULT_FORMAT}")
ext = pil_format_to_ext[DEFAULT_FORMAT]

# Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png".
# There is an extensive internal discussion here:
# https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
Expand All @@ -49,14 +66,18 @@ def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
# using the same artifact. Moreover, since we package the deserialization logic with the
# object payload, we can always change the serialization logic later without breaking
# existing payloads.
with artifact.new_file("image.png", binary=True) as f:
obj.save(f, format="png") # type: ignore
fname = f"image.{ext}"
with artifact.new_file(fname, binary=True) as f:
obj.save(f, format=ext_to_pil_format[ext]) # type: ignore


def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
# Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
# on save.
path = artifact.path("image.png")
# Today, we assume there can only be 1 image in the artifact.
filename = next(iter(artifact.path_contents))
if not filename.startswith("image."):
raise ValueError(f"Expected filename to start with 'image.', got {filename}")

path = artifact.path(filename)
return Image.open(path)


Expand Down
56 changes: 56 additions & 0 deletions weave/utils/invertable_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from typing import Any, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class InvertableDict(MutableMapping[KT, VT]):
    """A bijective mapping that behaves like a dict.

    Every key maps to exactly one value and every value maps back to exactly
    one key. Two plain dicts are kept in sync: ``_forward`` (key -> value) and
    ``_backward`` (value -> key). Values must therefore be hashable.

    Invert the dict using the `inv` property.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the mapping from the same arguments that ``dict`` accepts.

        Raises:
            ValueError: If two keys map to the same value.
        """
        self._forward: dict[KT, VT] = dict(*args, **kwargs)
        self._backward: dict[VT, KT] = {}
        for key, value in self._forward.items():
            if value in self._backward:
                raise ValueError(f"Duplicate value found: {value}")
            self._backward[value] = key

    def __getitem__(self, key: KT) -> VT:
        """Return the value stored for ``key`` (``KeyError`` if absent)."""
        return self._forward[key]

    def __setitem__(self, key: KT, value: VT) -> None:
        """Map ``key`` to ``value``, replacing any previous value for ``key``.

        Raises:
            ValueError: If ``value`` is already mapped to by a different key.
                The mapping is left unchanged in that case.
        """
        # Validate before mutating anything: a rejected assignment must not
        # leave the forward and backward dicts out of sync.
        if value in self._backward and self._backward[value] != key:
            raise ValueError(f"Duplicate value found: {value}")
        if key in self._forward:
            # Drop the stale reverse entry for the value being replaced.
            del self._backward[self._forward[key]]
        self._forward[key] = value
        self._backward[value] = key

    def __delitem__(self, key: KT) -> None:
        """Remove ``key`` and its reverse entry (``KeyError`` if absent)."""
        value = self._forward.pop(key)
        del self._backward[value]

    def __iter__(self) -> Iterator[KT]:
        """Iterate over the keys of the forward mapping."""
        return iter(self._forward)

    def __len__(self) -> int:
        """Return the number of key/value pairs."""
        return len(self._forward)

    def __repr__(self) -> str:
        """Return the repr of the underlying forward dict."""
        return repr(self._forward)

    def __contains__(self, key: Any) -> bool:
        """Return whether ``key`` is a key of the forward mapping."""
        return key in self._forward

    @property
    def inv(self) -> InvertableDict[VT, KT]:
        """Return the inverse mapping (value -> key).

        The returned object shares its two underlying dicts with this one
        (swapped), so changes made through either object are visible in both.
        """
        res = InvertableDict[VT, KT]()
        res._forward = self._backward
        res._backward = self._forward
        return res
Loading