From dfd4cb936c70b0b00cf10dca5dbe2b927fb9e218 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Mon, 12 Sep 2022 17:03:50 -0700 Subject: [PATCH] Pickle Image (#160) --- python/lance/tests/test_types.py | 32 ++++++++++++++++++++++++++++---- python/lance/types/__init__.py | 13 ++++++++++--- python/lance/types/image.py | 10 +++++----- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/python/lance/tests/test_types.py b/python/lance/tests/test_types.py index 012271b65c..1a32661b19 100644 --- a/python/lance/tests/test_types.py +++ b/python/lance/tests/test_types.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import pickle import platform import numpy as np @@ -19,9 +19,17 @@ import pytest import lance -from lance.types import Box2dType, ImageType, LabelType, Point2dType -from lance.types.box import Box2dArray -from lance.types.label import LabelArray +from lance.types import ( + Box2dArray, + Box2dType, + Image, + ImageBinary, + ImageType, + ImageUri, + LabelArray, + LabelType, + Point2dType, +) if platform.system() != "Linux": pytest.skip(allow_module_level=True) @@ -123,3 +131,19 @@ def _test_extension_rt(tmp_path, ext_type, storage_arr): assert table["ext"].type == ext_type assert table["ext"].to_pylist() == arr.to_pylist() return table["ext"] + + +def test_pickle(tmp_path): + img = Image.create("uri") + assert isinstance(img, ImageUri) + with (tmp_path / "image").open("wb") as fh: + pickle.dump(img, fh) + with (tmp_path / "image").open("rb") as fh: + assert img == pickle.load(fh) + + img = Image.create(b"bytes") + assert isinstance(img, ImageBinary) + with (tmp_path / "image").open("wb") as fh: + pickle.dump(img, fh) + with (tmp_path / "image").open("rb") as fh: + assert img == pickle.load(fh) diff --git a/python/lance/types/__init__.py b/python/lance/types/__init__.py index 568978c95e..e090ed81e2 100644 --- a/python/lance/types/__init__.py +++ b/python/lance/types/__init__.py @@ -18,9 +18,16 @@ from pyarrow import ArrowKeyError from lance.types.base import Point2dType -from lance.types.box import Box2dType -from lance.types.image import ImageBinaryType, ImageType, ImageUriType -from lance.types.label import LabelType +from lance.types.box import Box2dArray, Box2dType +from lance.types.image import ( + Image, + ImageBinary, + ImageBinaryType, + ImageType, + ImageUri, + ImageUriType, +) +from lance.types.label import LabelArray, LabelType def register_extension_types(): diff --git a/python/lance/types/image.py b/python/lance/types/image.py index 2bada633d3..4284b509c6 100644 --- a/python/lance/types/image.py +++ b/python/lance/types/image.py @@ -91,16 +91,16 @@ class Image(ABC): representations """ - def __new__(cls, data: Union[bytes, str]): + @staticmethod + def create(data: Union[bytes, str]): if isinstance(data, bytes): - img = object.__new__(ImageBinary) + img = ImageBinary(data) elif isinstance(data, str): - img = object.__new__(ImageUri) + img = ImageUri(data) else: raise TypeError( - f"{cls.__name__} can only handle bytes or str " f"but got {type(data)}" + f"Image can only handle bytes or str " f"but got {type(data)}" ) - img.__init__(data) return img @classmethod