Skip to content

Commit

Permalink
feat(weave): Add support for jpegs and pngs (#3304)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Jan 14, 2025
1 parent a3f8faa commit 07e3c40
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 28 deletions.
55 changes: 35 additions & 20 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from pathlib import Path

Expand All @@ -20,54 +21,68 @@
"""


def test_image_publish(client: WeaveClient) -> None:
@pytest.fixture(
params=[
{"format": None}, # Default PIL Image
{"format": "JPEG"},
{"format": "PNG"},
]
)
def test_img(request) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
weave.publish(img)

ref = get_ref(img)
if (fmt := request.param["format"]) is not None:
buffer = io.BytesIO()
img.save(buffer, format=fmt)
img = Image.open(buffer)

return img


def test_image_publish(client: WeaveClient, test_img: Image.Image) -> None:
weave.publish(test_img)

ref = get_ref(test_img)

assert ref is not None
gotten_img = weave.ref(ref.uri()).get()
assert img.tobytes() == gotten_img.tobytes()
assert test_img.tobytes() == gotten_img.tobytes()


class ImageWrapper(weave.Object):
img: Image.Image


def test_image_as_property(client: WeaveClient) -> None:
def test_image_as_property(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_property"
img = Image.new("RGB", (512, 512), "purple")
img_wrapper = ImageWrapper(img=img)
assert img_wrapper.img == img
img_wrapper = ImageWrapper(img=test_img)
assert img_wrapper.img == test_img

weave.publish(img_wrapper)

ref = get_ref(img_wrapper)
assert ref is not None

gotten_img_wrapper = weave.ref(ref.uri()).get()
assert gotten_img_wrapper.img.tobytes() == img.tobytes()
assert gotten_img_wrapper.img.tobytes() == test_img.tobytes()


def test_image_as_dataset_cell(client: WeaveClient) -> None:
def test_image_as_dataset_cell(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_dataset_cell"
img = Image.new("RGB", (512, 512), "purple")
dataset = weave.Dataset(rows=[{"img": img}])
assert dataset.rows[0]["img"] == img
dataset = weave.Dataset(rows=[{"img": test_img}])
assert dataset.rows[0]["img"] == test_img

weave.publish(dataset)

ref = get_ref(dataset)
assert ref is not None

gotten_dataset = weave.ref(ref.uri()).get()
assert gotten_dataset.rows[0]["img"].tobytes() == img.tobytes()
assert gotten_dataset.rows[0]["img"].tobytes() == test_img.tobytes()


@weave.op
def image_as_solo_output(publish_first: bool) -> Image.Image:
img = Image.new("RGB", (512, 512), "purple")
def image_as_solo_output(publish_first: bool, img: Image.Image) -> Image.Image:
if publish_first:
weave.publish(img)
return img
Expand All @@ -79,9 +94,9 @@ def image_as_input_and_output_part(in_img: Image.Image) -> dict:


@pytest.mark.skip("Flaky in CI with Op loading exception.")
def test_image_as_call_io(client: WeaveClient) -> None:
def test_image_as_call_io(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io"
non_published_img = image_as_solo_output(publish_first=False)
non_published_img = image_as_solo_output(publish_first=False, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand All @@ -95,9 +110,9 @@ def test_image_as_call_io(client: WeaveClient) -> None:
assert image_as_input_and_output_part_call.output["out_img"].tobytes() == exp_bytes


def test_image_as_call_io_refs(client: WeaveClient) -> None:
def test_image_as_call_io_refs(client: WeaveClient, test_img: Image.Image) -> None:
client.project = "test_image_as_call_io_refs"
non_published_img = image_as_solo_output(publish_first=True)
non_published_img = image_as_solo_output(publish_first=True, img=test_img)
img_dict = image_as_input_and_output_part(non_published_img)

exp_bytes = non_published_img.tobytes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ const customWeaveTypeRegistry: {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.JpegImagePlugin.JpegImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'PIL.PngImagePlugin.PngImageFile': {
component: PILImageImage,
preferredRowHeight: 350,
},
'wave.Wave_read': {
component: AudioPlayer,
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import React from 'react';

import {LoadingDots} from '../../../../../LoadingDots';
import {NotApplicable} from '../../../Browse2/NotApplicable';
import {useWFHooks} from '../../pages/wfReactInterface/context';
import {CustomWeaveTypePayload} from '../customWeaveType.types';

type PILImageImageTypePayload = CustomWeaveTypePayload<
'PIL.Image.Image',
{'image.png': string}
{'image.jpg': string} | {'image.png': string}
>;

export const isPILImageImageType = (
Expand All @@ -21,20 +22,37 @@ export const PILImageImage: React.FC<{
data: PILImageImageTypePayload;
}> = props => {
const {useFileContent} = useWFHooks();

const imageTypes = {
'image.jpg': 'jpg',
'image.png': 'png',
} as const;

const imageKey = Object.keys(props.data.files).find(
key => key in imageTypes
) as keyof PILImageImageTypePayload['files'] | undefined;
const imageBinary = useFileContent(
props.entity,
props.project,
props.data.files['image.png']
imageKey ? props.data.files[imageKey] : '',
{skip: !imageKey}
);

if (!imageKey) {
return <NotApplicable />;
}
const imageFileExt = imageTypes[imageKey as keyof typeof imageTypes];

if (imageBinary.loading) {
return <LoadingDots />;
} else if (imageBinary.result == null) {
return <span></span>;
}

const arrayBuffer = imageBinary.result as any as ArrayBuffer;
const blob = new Blob([arrayBuffer], {type: 'image/png'});
const blob = new Blob([arrayBuffer], {
type: `image/${imageFileExt}`,
});
const url = URL.createObjectURL(blob);

// TODO: It would be nice to have a more general image render - similar to the
Expand Down
31 changes: 26 additions & 5 deletions weave/type_handlers/Image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from weave.trace import object_preparers, serializer
from weave.trace.custom_objs import MemTraceFilesArtifact
from weave.utils.invertable_dict import InvertableDict

try:
from PIL import Image
Expand All @@ -17,6 +18,16 @@

logger = logging.getLogger(__name__)

DEFAULT_FORMAT = "PNG"

pil_format_to_ext = InvertableDict[str, str](
{
"JPEG": "jpg",
"PNG": "png",
}
)
ext_to_pil_format = pil_format_to_ext.inv


class PILImagePreparer:
def should_prepare(self, obj: Any) -> bool:
Expand All @@ -35,6 +46,12 @@ def prepare(self, obj: Image.Image) -> None:


def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
fmt = getattr(obj, "format", DEFAULT_FORMAT)
ext = pil_format_to_ext.get(fmt)
if ext is None:
logger.warning(f"Unknown image format {fmt}, defaulting to {DEFAULT_FORMAT}")
ext = pil_format_to_ext[DEFAULT_FORMAT]

# Note: I am purposely ignoring the `name` here and hard-coding the filename to "image.png".
# There is an extensive internal discussion here:
# https://weightsandbiases.slack.com/archives/C03BSTEBD7F/p1723670081582949
Expand All @@ -49,14 +66,18 @@ def save(obj: Image.Image, artifact: MemTraceFilesArtifact, name: str) -> None:
# using the same artifact. Moreover, since we package the deserialization logic with the
# object payload, we can always change the serialization logic later without breaking
# existing payloads.
with artifact.new_file("image.png", binary=True) as f:
obj.save(f, format="png") # type: ignore
fname = f"image.{ext}"
with artifact.new_file(fname, binary=True) as f:
obj.save(f, format=ext_to_pil_format[ext]) # type: ignore


def load(artifact: MemTraceFilesArtifact, name: str) -> Image.Image:
# Note: I am purposely ignoring the `name` here and hard-coding the filename. See comment
# on save.
path = artifact.path("image.png")
# Today, we assume there can only be 1 image in the artifact.
filename = next(iter(artifact.path_contents))
if not filename.startswith("image."):
raise ValueError(f"Expected filename to start with 'image.', got {filename}")

path = artifact.path(filename)
return Image.open(path)


Expand Down
56 changes: 56 additions & 0 deletions weave/utils/invertable_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from typing import Any, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class InvertableDict(MutableMapping[KT, VT]):
    """A bijective mapping that behaves like a dict.

    Every key maps to exactly one value and every value maps back to exactly
    one key, so both keys and values must be hashable. Invert the dict using
    the `inv` property.

    Raises:
        ValueError: On construction or assignment, if a value would end up
            associated with more than one key.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the mapping from the same arguments `dict()` accepts."""
        self._forward: dict[KT, VT] = dict(*args, **kwargs)
        self._backward: dict[VT, KT] = {}
        for key, value in self._forward.items():
            if value in self._backward:
                raise ValueError(f"Duplicate value found: {value}")
            self._backward[value] = key

    def __getitem__(self, key: KT) -> VT:
        return self._forward[key]

    def __setitem__(self, key: KT, value: VT) -> None:
        """Associate `key` with `value`, replacing any previous value of `key`.

        Re-assigning a key to the value it already holds is allowed. Assigning
        a value that is currently owned by a different key raises ValueError
        and leaves the mapping untouched.
        """
        key_present = key in self._forward
        # Validate before mutating so that a rejected assignment cannot leave
        # the forward and backward maps out of sync.
        if value in self._backward and not (
            key_present and self._forward[key] == value
        ):
            raise ValueError(f"Duplicate value found: {value}")
        if key_present:
            # Drop the stale reverse entry for the value being replaced.
            del self._backward[self._forward[key]]
        self._forward[key] = value
        self._backward[value] = key

    def __delitem__(self, key: KT) -> None:
        value = self._forward.pop(key)
        del self._backward[value]

    def __iter__(self) -> Iterator[KT]:
        return iter(self._forward)

    def __len__(self) -> int:
        return len(self._forward)

    def __repr__(self) -> str:
        return repr(self._forward)

    def __contains__(self, key: Any) -> bool:
        return key in self._forward

    @property
    def inv(self) -> InvertableDict[VT, KT]:
        """Return the inverse mapping (values -> keys).

        The returned object shares its underlying storage with this one, so a
        change made through either mapping is visible through the other.
        """
        res = InvertableDict[VT, KT]()
        # Swap the two underlying dicts instead of copying them.
        res._forward = self._backward
        res._backward = self._forward
        return res

0 comments on commit 07e3c40

Please sign in to comment.