diff --git a/python/lance/__init__.py b/python/lance/__init__.py index a0fb9f3e15..5ace46d79b 100644 --- a/python/lance/__init__.py +++ b/python/lance/__init__.py @@ -12,13 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import platform from pathlib import Path -from typing import Union, Optional +from typing import Optional, Union import pyarrow as pa import pyarrow.compute as pc import pyarrow.dataset as ds -from lance.lib import LanceFileFormat, WriteTable, BuildScanner + +from lance.lib import BuildScanner, LanceFileFormat, WriteTable +from lance.types import register_extension_types + +if platform.system() == "Linux": + # TODO enable on MacOS + register_extension_types() __all__ = ["dataset", "write_table", "scanner"] diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx index edd8c7846b..d1dd514ce6 100644 --- a/python/lance/_lib.pyx +++ b/python/lance/_lib.pyx @@ -1,20 +1,33 @@ # distutils: language = c++ -from typing import Union, Optional +from typing import Optional, Union from cython.operator cimport dereference as deref from libcpp cimport bool from libcpp.memory cimport shared_ptr from libcpp.string cimport string + from pathlib import Path + from pyarrow import Table -from pyarrow._dataset cimport FileFormat, CScanner, CDataset, Dataset + +from pyarrow._dataset cimport CDataset, CScanner, Dataset, FileFormat + from pyarrow._dataset import Scanner -from pyarrow.includes.common cimport * + from pyarrow._compute cimport Expression, _bind -from pyarrow.includes.libarrow cimport CTable, COutputStream +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport COutputStream, CTable from pyarrow.includes.libarrow_dataset cimport CFileFormat -from pyarrow.lib cimport GetResultValue, check_status, pyarrow_unwrap_table, get_writer, RecordBatchReader, CExpression +from pyarrow.lib cimport ( + CExpression, + GetResultValue, + RecordBatchReader, + check_status, + get_writer, 
+ pyarrow_unwrap_table, +) + from pyarrow.lib import tobytes diff --git a/python/lance/tests/api_test.py b/python/lance/tests/test_api.py similarity index 85% rename from python/lance/tests/api_test.py rename to python/lance/tests/test_api.py index e19cdf979d..ea783f0d7f 100644 --- a/python/lance/tests/api_test.py +++ b/python/lance/tests/test_api.py @@ -17,11 +17,14 @@ import pandas as pd import pyarrow as pa -from lance import write_table, dataset + +from lance import dataset, write_table def test_simple_round_trips(tmp_path: Path): - table = pa.Table.from_pandas(pd.DataFrame({"label": [123, 456, 789], "values": [22, 33, 2.24]})) + table = pa.Table.from_pandas( + pd.DataFrame({"label": [123, 456, 789], "values": [22, 33, 2.24]}) + ) write_table(table, tmp_path / "test.lance") assert (tmp_path / "test.lance").exists() @@ -29,7 +32,7 @@ def test_simple_round_trips(tmp_path: Path): ds = dataset(str(tmp_path / "test.lance")) actual = ds.to_table() - assert (table == actual) + assert table == actual def test_write_categorical_values(tmp_path: Path): @@ -41,4 +44,4 @@ def test_write_categorical_values(tmp_path: Path): assert (tmp_path / "test.lance").exists() actual = dataset(str(tmp_path / "test.lance")).to_table() - assert (table == actual) + assert table == actual diff --git a/python/lance/tests/duckdb_test.py b/python/lance/tests/test_duckdb.py similarity index 91% rename from python/lance/tests/duckdb_test.py rename to python/lance/tests/test_duckdb.py index 9955e75cfd..5f9840d13c 100644 --- a/python/lance/tests/duckdb_test.py +++ b/python/lance/tests/test_duckdb.py @@ -18,13 +18,16 @@ from pathlib import Path import duckdb -import lance import pandas as pd import pyarrow as pa +import lance + def test_dictionary_type_query(tmp_path: Path): - df = pd.DataFrame({"class": ["foo", "bar", "foo", "zoo"], "grade": ["A", "B", "B", "A"]}) + df = pd.DataFrame( + {"class": ["foo", "bar", "foo", "zoo"], "grade": ["A", "B", "B", "A"]} + ) # df["class"] = 
df["class"].astype("category") # df["grade"] = df["grade"].astype("category")
diff --git a/python/lance/tests/test_types.py b/python/lance/tests/test_types.py
new file mode 100644
index 0000000000..5c4d59fc26
--- /dev/null
+++ b/python/lance/tests/test_types.py
@@ -0,0 +1,66 @@
+# Copyright 2022 Lance Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+
+import pyarrow as pa
+import pytest
+
+import lance
+from lance.types import *
+
+# The extension types refuse to construct off Linux, so skip the whole module.
+if platform.system() != "Linux":
+    pytest.skip("extension types are Linux-only", allow_module_level=True)
+
+
+def test_image(tmp_path):
+    """Round-trip a utf8 (URI) image column through a Lance file."""
+    data = [f"s3://bucket/{x}.jpg" for x in ["a", "b", "c"]]
+    storage = pa.array(data, type=pa.utf8())
+    image_type = ImageType.from_storage(storage.type)
+    _test_extension_rt(tmp_path, image_type, storage)
+
+
+def test_image_binary(tmp_path):
+    """Round-trip a binary image column through a Lance file."""
+    data = [b"" for _ in range(3)]
+    # Spell the storage type out instead of relying on inference.
+    storage = pa.array(data, type=pa.binary())
+    image_type = ImageType.from_storage(storage.type)
+    _test_extension_rt(tmp_path, image_type, storage)
+
+
+def test_point(tmp_path):
+    """Round-trip Point2dType values stored as list<float64>."""
+    point_type = Point2dType()
+    data = [(float(x), float(x)) for x in range(100)]
+    storage = pa.array(data, pa.list_(pa.float64()))
+    _test_extension_rt(tmp_path, point_type, storage)
+
+
+def test_box2d(tmp_path):
+    """Round-trip Box2dType values stored as list<float64>."""
+    box_type = Box2dType()
+    data = [(float(x), float(x), float(x), float(x)) for x in range(100)]
+    storage = pa.array(data, pa.list_(pa.float64()))
+    _test_extension_rt(tmp_path, box_type, storage)
+
+
+def _test_extension_rt(tmp_path, ext_type,
storage_arr):
+    arr = pa.ExtensionArray.from_storage(ext_type, storage_arr)
+    table = pa.Table.from_arrays([arr], names=["ext"])
+    lance.write_table(table, str(tmp_path / "test.lance"))
+    table = lance.dataset(str(tmp_path / "test.lance")).to_table()
+    assert table["ext"].type == ext_type
+    assert table["ext"].to_pylist() == storage_arr.to_pylist()
diff --git a/python/lance/types.py b/python/lance/types.py
new file mode 100644
index 0000000000..949124411d
--- /dev/null
+++ b/python/lance/types.py
@@ -0,0 +1,141 @@
+# Copyright 2022 Lance Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Arrow extension types for Lance
+"""
+import platform
+from abc import ABC, abstractproperty
+
+import pandas as pd
+import pyarrow as pa
+from pyarrow import ArrowKeyError
+
+__all__ = ["ImageType", "ImageUriType", "ImageBinaryType", "Point2dType", "Box2dType"]
+
+
+class LanceType(pa.ExtensionType, ABC):
+    """
+    Common base class for the Lance Arrow extension types.
+
+    Instantiation raises NotImplementedError on any platform except Linux.
+    """
+
+    def __init__(self, storage_type, extension_name):
+        if platform.system() != "Linux":
+            raise NotImplementedError(
+                "Extension types are enabled for linux only for now"
+            )
+        super().__init__(storage_type, extension_name)
+
+
+class ImageType(LanceType):
+    """
+    Base type for images. Images can be stored either as a URI pointer or as
+    bytes directly. Use ImageType.from_storage(storage_arr.type) to get an
+    instance of the correct type (only utf8 and binary storage types are
+    supported; anything else raises NotImplementedError).
+    """
+
+    def __arrow_ext_serialize__(self):
+        return b""
+
+    @classmethod
+    def from_storage(cls, storage_type):
+        # TODO consider parameterizing types to support utf* variants
+        # and also large binary (for geo or medical imaging)
+        if storage_type == pa.utf8():
+            return ImageUriType()
+        if storage_type == pa.binary():
+            return ImageBinaryType()
+        # Every other storage type is rejected.
+        raise NotImplementedError(f"Unrecognized image storage type {storage_type}")
+
+
+class ImageUriType(ImageType):
+    """
+    An image kept outside the dataset and referenced only by its uri.
+    Storage type is utf8.
+    """
+
+    def __init__(self):
+        super().__init__(pa.utf8(), "image[uri]")
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return ImageUriType()
+
+
+class ImageBinaryType(ImageType):
+    """
+    An image stored inline: the column holds the actual image bytes.
+    Storage type is binary.
+
+    TODO: add support for large binary
+    """
+
+    def __init__(self):
+        super().__init__(pa.binary(), "image[binary]")
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return ImageBinaryType()
+
+
+# TODO turn these into fixed sized list arrays once GH#101 is done
+class Point2dType(LanceType):
+    """
+    A point in 2D space, stored as a list of float64 values.
+    """
+
+    def __init__(self):
+        super().__init__(pa.list_(pa.float64()), "point2d")
+
+    def __arrow_ext_serialize__(self):
+        return b""
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return Point2dType()
+
+
+# TODO turn these into fixed sized list arrays once GH#101 is done
+class Box2dType(LanceType):
+    """
+    A rectangular box in 2D space (usually used for bounding boxes).
+    Represented as 2 Point2Ds (top-left and bottom-right corners)
+    """
+
+    def __init__(self):
+        super(Box2dType, self).__init__(pa.list_(pa.float64()), "box2d")
+
+    def __arrow_ext_serialize__(self):
+        return b""
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return Box2dType()
+
+
+def register_extension_types():
+    """Register every Lance extension type with Arrow (Linux only for now)."""
+    if platform.system() != "Linux":
+        raise NotImplementedError("Extension types are only supported on Linux for now")
+    for ext_cls in (ImageUriType, ImageBinaryType, Point2dType, Box2dType):
+        try:
+            pa.register_extension_type(ext_cls())
+        except ArrowKeyError:
+            # Already registered: move on, so that one duplicate does not
+            # stop the remaining types from being registered.
+            continue