From 294434df7eb12aa46874bf071732a20c7dc1cf9f Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Wed, 7 Sep 2022 09:54:24 -0700
Subject: [PATCH] add label type and fmt pylance using isort/black

---
 python/benchmarks/bench_utils.py   |   7 +-
 python/benchmarks/coco.py          |  66 ++++++-----
 python/benchmarks/oxford_pet.py    |  42 +++----
 python/benchmarks/parse_bdd100k.py |  25 ++++-
 python/benchmarks/parse_coco.py    |   1 -
 python/benchmarks/parse_pet.py     | 175 +++++++++++++++++------------
 python/lance/_lib.pyx              |  13 ++-
 python/lance/tests/test_api.py     |  12 +-
 python/lance/tests/test_types.py   |  11 ++
 python/lance/types.py              |  28 ++++-
 python/setup.py                    |   4 +-
 11 files changed, 239 insertions(+), 145 deletions(-)

diff --git a/python/benchmarks/bench_utils.py b/python/benchmarks/bench_utils.py
index 1eb0a5d2a1..67cf37ab9f 100644
--- a/python/benchmarks/bench_utils.py
+++ b/python/benchmarks/bench_utils.py
@@ -12,20 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import multiprocessing as mp
 import os
 import pathlib
+import time
 from abc import ABC, abstractmethod
 from functools import wraps
-import multiprocessing as mp
 from typing import Iterable, Union

 import click
 import pandas as pd
-import time
-
 import pyarrow as pa
-import pyarrow.fs
 import pyarrow.dataset as ds
+import pyarrow.fs
 import pyarrow.parquet as pq

 import lance
diff --git a/python/benchmarks/coco.py b/python/benchmarks/coco.py
index 0abb8db517..b90e61dbd3 100755
--- a/python/benchmarks/coco.py
+++ b/python/benchmarks/coco.py
@@ -4,38 +4,38 @@

 import duckdb
 import pandas as pd
-
-import lance
 import pyarrow.compute as pc
 import pyarrow.dataset as ds
-from bench_utils import download_uris, get_uri, get_dataset, BenchmarkSuite
+from bench_utils import BenchmarkSuite, download_uris, get_dataset, get_uri
 from parse_coco import CocoConverter

+import lance
+
 coco_benchmarks = BenchmarkSuite("coco")


-@coco_benchmarks.benchmark("label_distribution", key=['fmt', 'flavor'])
+@coco_benchmarks.benchmark("label_distribution", key=["fmt", "flavor"])
 def label_distribution(base_uri: str, fmt: str, flavor: str = None):
-    if fmt == 'raw':
+    if fmt == "raw":
         return _label_distribution_raw(base_uri)
-    elif fmt == 'lance':
+    elif fmt == "lance":
         uri = get_uri(base_uri, "coco", fmt, flavor)
         dataset = get_dataset(uri)
         return _label_distribution_lance(dataset)
-    elif fmt == 'parquet':
+    elif fmt == "parquet":
         uri = get_uri(base_uri, "coco", fmt, flavor)
         dataset = get_dataset(uri)
         return _label_distribution_duckdb(dataset)
     raise NotImplementedError()


-@coco_benchmarks.benchmark("filter_data", key=['fmt', 'flavor'])
+@coco_benchmarks.benchmark("filter_data", key=["fmt", "flavor"])
 def filter_data(base_uri: str, fmt: str, flavor: str = None):
-    if fmt == 'raw':
+    if fmt == "raw":
         return _filter_data_raw(base_uri)
-    elif fmt == 'lance':
+    elif fmt == "lance":
         return _filter_data_lance(base_uri, flavor=flavor)
-    elif fmt == 'parquet':
+    elif fmt == "parquet":
         return _filter_data_parquet(base_uri, flavor=flavor)
     raise NotImplementedError()

@@ -55,40 +55,50 @@ def _filter_data_raw(base_uri: str, klass="cat", offset=20, limit=50):
     df = c.read_metadata()
     mask = df.annotations.apply(lambda ann: any([a["name"] == klass for a in ann]))
     filtered = df.loc[mask, ["image_uri", "annotations"]]
-    limited = filtered[offset:offset + limit]
+    limited = filtered[offset : offset + limit]
     limited = limited.assign(image=download_uris(limited.image_uri))
     return limited


 def _filter_data_lance(base_uri: str, klass="cat", offset=20, limit=50, flavor=None):
     uri = get_uri(base_uri, "coco", "lance", flavor)
-    index_scanner = lance.scanner(uri, columns=['image_id', 'annotations.name'])
-    query = (f"SELECT distinct image_id FROM ("
-             f"  SELECT image_id, UNNEST(annotations) as ann FROM index_scanner"
-             f") WHERE ann.name == '{klass}'")
+    index_scanner = lance.scanner(uri, columns=["image_id", "annotations.name"])
+    query = (
+        f"SELECT distinct image_id FROM ("
+        f"  SELECT image_id, UNNEST(annotations) as ann FROM index_scanner"
+        f") WHERE ann.name == '{klass}'"
+    )
     filtered_ids = duckdb.query(query).arrow().column("image_id").combine_chunks()
-    scanner = lance.scanner(uri, ['image_id', 'image', 'annotations.name'],
-                            # filter=pc.field("image_id").isin(filtered_ids),
-                            limit=50, offset=20)
+    scanner = lance.scanner(
+        uri,
+        ["image_id", "image", "annotations.name"],
+        # filter=pc.field("image_id").isin(filtered_ids),
+        limit=limit,
+        offset=offset,
+    )
     return scanner.to_table().to_pandas()


 def _filter_data_parquet(base_uri: str, klass="cat", offset=20, limit=50, flavor=None):
     uri = get_uri(base_uri, "coco", "parquet", flavor)
     dataset = ds.dataset(uri)
-    query = (f"SELECT distinct image_id FROM ("
-             f"  SELECT image_id, UNNEST(annotations) as ann FROM dataset"
-             f") WHERE ann.name == '{klass}'")
+    query = (
+        f"SELECT distinct image_id FROM ("
+        f"  SELECT image_id, UNNEST(annotations) as ann FROM dataset"
+        f") WHERE ann.name == '{klass}'"
+    )
     filtered_ids = duckdb.query(query).arrow().column("image_id").to_numpy().tolist()
-    id_string = ','.join([f"'{x}'" for x in filtered_ids])
-    return duckdb.query(f"SELECT image, annotations "
-                        f"FROM dataset "
-                        f"WHERE image_id in ({id_string}) "
-                        f"LIMIT 50 OFFSET 20").to_arrow_table()
+    id_string = ",".join([f"'{x}'" for x in filtered_ids])
+    return duckdb.query(
+        f"SELECT image, annotations "
+        f"FROM dataset "
+        f"WHERE image_id in ({id_string}) "
+        f"LIMIT {limit} OFFSET {offset}"
+    ).to_arrow_table()


 def _label_distribution_lance(dataset: ds.Dataset):
-    scanner = lance.scanner(dataset, columns=['annotations.name'])
+    scanner = lance.scanner(dataset, columns=["annotations.name"])
     return _label_distribution_duckdb(scanner)


diff --git a/python/benchmarks/oxford_pet.py b/python/benchmarks/oxford_pet.py
index bdea479db4..48369ec2a9 100755
--- a/python/benchmarks/oxford_pet.py
+++ b/python/benchmarks/oxford_pet.py
@@ -6,51 +6,54 @@
 import duckdb
 import numpy as np
 import pandas as pd
-
-import lance
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.dataset
 from bench_utils import BenchmarkSuite, download_uris
 from parse_pet import OxfordPetConverter

+import lance
+
 oxford_pet_benchmarks = BenchmarkSuite("oxford_pet")


-@oxford_pet_benchmarks.benchmark("label_distribution", key=['fmt', 'flavor'])
+@oxford_pet_benchmarks.benchmark("label_distribution", key=["fmt", "flavor"])
 def label_distribution(base_uri: str, fmt: str, flavor: Optional[str]):
     if fmt == "raw":
         return get_pets_class_distribution(base_uri)
-    suffix = '' if not flavor else f'_{flavor}'
-    ds = _get_dataset(os.path.join(base_uri, f'oxford_pet{suffix}.{fmt}'), fmt)
+    suffix = "" if not flavor else f"_{flavor}"
+    ds = _get_dataset(os.path.join(base_uri, f"oxford_pet{suffix}.{fmt}"), fmt)
     query = "SELECT class, count(1) FROM ds GROUP BY 1"
     return duckdb.query(query).to_df()


-@oxford_pet_benchmarks.benchmark("filter_data", key=['fmt', 'flavor'])
+@oxford_pet_benchmarks.benchmark("filter_data", key=["fmt", "flavor"])
key=["fmt", "flavor"]) def filter_data(base_uri: str, fmt: str, flavor: Optional[str]): if fmt == "raw": return get_pets_filtered_data(base_uri) - suffix = '' if not flavor else f'_{flavor}' - uri = os.path.join(base_uri, f'oxford_pet{suffix}.{fmt}') + suffix = "" if not flavor else f"_{flavor}" + uri = os.path.join(base_uri, f"oxford_pet{suffix}.{fmt}") if fmt == "parquet": ds = _get_dataset(uri, fmt) - query = ("SELECT image, class FROM ds WHERE class='pug' " - "LIMIT 50 OFFSET 20") + query = "SELECT image, class FROM ds WHERE class='pug' " "LIMIT 50 OFFSET 20" return duckdb.query(query).to_df() elif fmt == "lance": - scanner = lance.scanner(uri, columns=["image", "class"], - filter=pc.field("class") == "pug", - limit=50, offset=20) + scanner = lance.scanner( + uri, + columns=["image", "class"], + filter=pc.field("class") == "pug", + limit=50, + offset=20, + ) return scanner.to_table().to_pandas() -@oxford_pet_benchmarks.benchmark("area_histogram", key=['fmt', 'flavor']) +@oxford_pet_benchmarks.benchmark("area_histogram", key=["fmt", "flavor"]) def compute_histogram(base_uri: str, fmt: str, flavor: Optional[str]): if fmt == "raw": return area_histogram_raw(base_uri) - suffix = '' if not flavor else f'_{flavor}' - uri = os.path.join(base_uri, f'oxford_pet{suffix}.{fmt}') + suffix = "" if not flavor else f"_{flavor}" + uri = os.path.join(base_uri, f"oxford_pet{suffix}.{fmt}") ds = _get_dataset(uri, fmt) query = "SELECT histogram(size.width * size.height) FROM ds" return duckdb.query(query).to_df() @@ -74,16 +77,15 @@ def get_pets_filtered_data(base_uri, klass="pug", offset=20, limit=50): c = OxfordPetConverter(base_uri) df = c.read_metadata() filtered = df.loc[df["class"] == klass, ["class", "filename"]] - limited: pd.DataFrame = filtered[offset: offset + limit] - uris = [os.path.join(base_uri, f"images/{x}.jpg") - for x in limited.filename.values] + limited: pd.DataFrame = filtered[offset : offset + limit] + uris = [os.path.join(base_uri, f"images/{x}.jpg") for x in limited.filename.values] return limited.assign(images=download_uris(pd.Series(uris))) def area_histogram_raw(base_uri): c = OxfordPetConverter(base_uri) df = c.read_metadata() - sz = pd.json_normalize(df['size']) + sz = pd.json_normalize(df["size"]) query = "SELECT histogram(width * height) FROM sz" return duckdb.query(query).to_df() diff --git a/python/benchmarks/parse_bdd100k.py b/python/benchmarks/parse_bdd100k.py index 3b139672e8..1cd7c48ea1 100755 --- a/python/benchmarks/parse_bdd100k.py +++ b/python/benchmarks/parse_bdd100k.py @@ -6,13 +6,13 @@ from typing import Union import click -import lance import pandas as pd import pyarrow as pa import pyarrow.fs - from bench_utils import DatasetConverter +import lance + class BDD100kConverter(DatasetConverter): def __init__(self, uri_root: Union[str, Path]): @@ -23,12 +23,18 @@ def read_metadata(self) -> pd.DataFrame: for split in ["train", "val"]: annotation = pd.read_json( os.path.join( - self.uri_root, "bdd100k", "labels", f"bdd100k_labels_images_{split}.json" + self.uri_root, + "bdd100k", + "labels", + f"bdd100k_labels_images_{split}.json", ) ) annotation["split"] = split annotation["image_uri"] = annotation["name"].map( - lambda name: os.path.join(self.uri_root, "bdd100k", "images", "100k", split, name)) + lambda name: os.path.join( + self.uri_root, "bdd100k", "images", "100k", split, name + ) + ) frames.append(annotation) return pd.concat(frames) @@ -82,7 +88,12 @@ def get_schema(self): @click.command @click.option("-u", "--base-uri", type=str, required=True, help="Coco 
dataset root") -@click.option("-f", "--fmt", type=click.Choice(["lance", "parquet"]), help="Output format (parquet or lance)") +@click.option( + "-f", + "--fmt", + type=click.Choice(["lance", "parquet"]), + help="Output format (parquet or lance)", +) @click.option("-e", "--embedded", type=bool, default=True, help="Embed images") @click.option( "-o", @@ -103,7 +114,9 @@ def main(base_uri, fmt, embedded, output_path): partitioning = ["split"] for f in fmt: if embedded: - converter.make_embedded_dataset(df, f, output_path, partitioning=partitioning) + converter.make_embedded_dataset( + df, f, output_path, partitioning=partitioning + ) else: return converter.save_df(df, f, output_path, partitioning=partitioning) diff --git a/python/benchmarks/parse_coco.py b/python/benchmarks/parse_coco.py index 786dce21aa..8d8c3dcedf 100755 --- a/python/benchmarks/parse_coco.py +++ b/python/benchmarks/parse_coco.py @@ -6,7 +6,6 @@ import click import pandas as pd import pyarrow as pa - from bench_utils import DatasetConverter diff --git a/python/benchmarks/parse_pet.py b/python/benchmarks/parse_pet.py index a6f26c15df..a2caa9636b 100755 --- a/python/benchmarks/parse_pet.py +++ b/python/benchmarks/parse_pet.py @@ -18,17 +18,15 @@ import pathlib import sys from typing import Iterable -import click +import click import numpy as np import pandas as pd -import xmltodict - import pyarrow as pa import pyarrow.fs +import xmltodict from bench_utils import DatasetConverter, download_uris, read_file - # Oxford PET has dataset quality issues: # # The following exists in the XMLs but are not part of the list.txt index @@ -43,7 +41,6 @@ class OxfordPetConverter(DatasetConverter): - def __init__(self, uri_root): super(OxfordPetConverter, self).__init__("oxford_pet", uri_root) self._data_quality_issues = {} @@ -60,18 +57,20 @@ def read_metadata(self, check_quality=False) -> pd.DataFrame: split = pd.concat([train, val, test]) split.name = "split" split = split.reset_index() - with_split = df.merge(split, how='left', on="filename") - xml_files = (os.path.join(self.uri_root, "annotations", "xmls/") - + with_split.filename + ".xml") + with_split = df.merge(split, how="left", on="filename") + xml_files = ( + os.path.join(self.uri_root, "annotations", "xmls/") + + with_split.filename + + ".xml" + ) ann_df = pd.DataFrame(download_uris(xml_files, func=_get_xml)) - with_xmls = pd.concat([with_split, ann_df.drop(columns=['filename'])], - axis=1) + with_xmls = pd.concat([with_split, ann_df.drop(columns=["filename"])], axis=1) if check_quality: - trainval = df[df.split.isin(['train', 'val'])] - self._data_quality_issues["missing_xml"] = ( - trainval[trainval.folder.isna()].filename.values.tolist() - ) + trainval = df[df.split.isin(["train", "val"])] + self._data_quality_issues["missing_xml"] = trainval[ + trainval.folder.isna() + ].filename.values.tolist() p = pathlib.Path(self.uri_root) / "annotations" / "xmls" names = pd.Series([p.name[:-4] for p in p.iterdir()]) @@ -79,70 +78,94 @@ def read_metadata(self, check_quality=False) -> pd.DataFrame: self._data_quality_issues["missing_index"] = no_index # TODO lance doesn't support writing booleans yet - with_xmls['segmented'] = with_xmls.segmented.astype(pd.Int8Dtype()) + with_xmls["segmented"] = with_xmls.segmented.astype(pd.Int8Dtype()) return with_xmls def _get_index(self, name: str) -> pd.DataFrame: - list_txt = os.path.join(self.uri_root, - f"annotations/{name}.txt") + list_txt = os.path.join(self.uri_root, f"annotations/{name}.txt") df = pd.read_csv(list_txt, delimiter=" ", 
comment="#", header=None) df.columns = ["filename", "class", "species", "breed"] return df @staticmethod def _find_split_index(trainval_df): - classnames = trainval_df.filename.str.rsplit('_', 1).str[0].str.lower() + classnames = trainval_df.filename.str.rsplit("_", 1).str[0].str.lower() return np.argmax(classnames < classnames.shift(1)) @staticmethod def _to_category(metadata_df: pd.DataFrame): species_dtype = pd.CategoricalDtype(["Unknown", "Cat", "Dog"]) - metadata_df['species'] = pd.Categorical.from_codes(metadata_df.species, dtype=species_dtype) + metadata_df["species"] = pd.Categorical.from_codes( + metadata_df.species, dtype=species_dtype + ) breeds = metadata_df.filename.str.rsplit("_", 1).str[0].unique() assert len(breeds) == 37 breeds = np.concatenate([["Unknown"], breeds]) class_dtype = pd.CategoricalDtype(breeds) - metadata_df["class"] = pd.Categorical.from_codes(metadata_df["class"], dtype=class_dtype) + metadata_df["class"] = pd.Categorical.from_codes( + metadata_df["class"], dtype=class_dtype + ) return metadata_df def default_dataset_path(self, fmt, flavor=None): suffix = f"_{flavor}" if flavor else "" - return os.path.join(self.uri_root, - f'{self.name}{suffix}.{fmt}') + return os.path.join(self.uri_root, f"{self.name}{suffix}.{fmt}") def image_uris(self, table): - return [os.path.join(self.uri_root, f"images/{x}.jpg") - for x in table["filename"].to_numpy()] + return [ + os.path.join(self.uri_root, f"images/{x}.jpg") + for x in table["filename"].to_numpy() + ] def get_schema(self): - source_schema = pa.struct([ - pa.field("database", pa.string()), - pa.field("annotation", pa.string()), - pa.field("image", pa.string()) - ]) - size_schema = pa.struct([ - pa.field("width", pa.int32()), - pa.field("height", pa.int32()), - pa.field("depth", pa.uint8()) - ]) - bbox = pa.struct([ - pa.field("xmin", pa.int32()), - pa.field("ymin", pa.int32()), - pa.field("xmax", pa.int32()), - pa.field("ymax", pa.int32()), - ]) - object_schema = pa.list_(pa.struct([ - pa.field("name", pa.dictionary(pa.uint8(), pa.string())), - pa.field("pose", pa.dictionary(pa.uint8(), pa.string())), - pa.field("truncated", pa.bool_()), - pa.field("occluded", pa.bool_()), - pa.field("bndbox", bbox), - pa.field("difficult", pa.bool_()) - ])) - names = ["filename", "class", "species", "breed", "split", - "folder", "source", "size", "segmented", "object"] + source_schema = pa.struct( + [ + pa.field("database", pa.string()), + pa.field("annotation", pa.string()), + pa.field("image", pa.string()), + ] + ) + size_schema = pa.struct( + [ + pa.field("width", pa.int32()), + pa.field("height", pa.int32()), + pa.field("depth", pa.uint8()), + ] + ) + bbox = pa.struct( + [ + pa.field("xmin", pa.int32()), + pa.field("ymin", pa.int32()), + pa.field("xmax", pa.int32()), + pa.field("ymax", pa.int32()), + ] + ) + object_schema = pa.list_( + pa.struct( + [ + pa.field("name", pa.dictionary(pa.uint8(), pa.string())), + pa.field("pose", pa.dictionary(pa.uint8(), pa.string())), + pa.field("truncated", pa.bool_()), + pa.field("occluded", pa.bool_()), + pa.field("bndbox", bbox), + pa.field("difficult", pa.bool_()), + ] + ) + ) + names = [ + "filename", + "class", + "species", + "breed", + "split", + "folder", + "source", + "size", + "segmented", + "object", + ] types = [ pa.string(), pa.dictionary(pa.uint8(), pa.string()), @@ -153,32 +176,31 @@ def get_schema(self): source_schema, size_schema, pa.uint8(), - object_schema + object_schema, ] - return pa.schema([pa.field(name, dtype) - for name, dtype in zip(names, types)]) + return 


 def _get_xml(uri):
     fs, key = pa.fs.FileSystem.from_uri(uri)
     try:
         with fs.open_input_file(key) as fh:
-            dd = xmltodict.parse(fh.read())['annotation']
-            if not isinstance(dd['object'], list):
-                dd['object'] = [dd['object']]
-            sz = dd['size']
-            sz['width'] = int(sz['width'])
-            sz['height'] = int(sz['height'])
-            sz['depth'] = int(sz['depth'])
-            for obj in dd['object']:
-                obj['truncated'] = bool(int(obj['truncated']))
-                obj['occluded'] = bool(int(obj['occluded']))
-                obj['difficult'] = bool(int(obj['difficult']))
-                obj['bndbox'] = {
-                    'xmin': int(obj['bndbox']['xmin']),
-                    'xmax': int(obj['bndbox']['xmax']),
-                    'ymin': int(obj['bndbox']['ymin']),
-                    'ymax': int(obj['bndbox']['ymax']),
+            dd = xmltodict.parse(fh.read())["annotation"]
+            if not isinstance(dd["object"], list):
+                dd["object"] = [dd["object"]]
+            sz = dd["size"]
+            sz["width"] = int(sz["width"])
+            sz["height"] = int(sz["height"])
+            sz["depth"] = int(sz["depth"])
+            for obj in dd["object"]:
+                obj["truncated"] = bool(int(obj["truncated"]))
+                obj["occluded"] = bool(int(obj["occluded"]))
+                obj["difficult"] = bool(int(obj["difficult"]))
+                obj["bndbox"] = {
+                    "xmin": int(obj["bndbox"]["xmin"]),
+                    "xmax": int(obj["bndbox"]["xmax"]),
+                    "ymin": int(obj["bndbox"]["ymin"]),
+                    "ymax": int(obj["bndbox"]["ymax"]),
                 }
             return dd
     except Exception:
@@ -186,10 +208,19 @@


 @click.command
-@click.option("-u", "--base-uri", type=str, required=True, help="Oxford Pet dataset root")
-@click.option("-f", "--fmt", type=click.Choice(["parquet", "lance"]), help="Output format (parquet or lance)")
+@click.option(
+    "-u", "--base-uri", type=str, required=True, help="Oxford Pet dataset root"
+)
+@click.option(
+    "-f",
+    "--fmt",
+    type=click.Choice(["parquet", "lance"]),
+    help="Output format (parquet or lance)",
+)
 @click.option("-e", "--embedded", type=bool, default=True, help="store embedded images")
-@click.option("-o", "--output", type=str, default="oxford_pet.lance", help="Output path")
+@click.option(
+    "-o", "--output", type=str, default="oxford_pet.lance", help="Output path"
+)
 def main(base_uri, fmt, embedded, output):
     known_formats = ["lance", "parquet"]
     if fmt is not None:
diff --git a/python/lance/_lib.pyx b/python/lance/_lib.pyx
index 750307e69e..73f2d72231 100644
--- a/python/lance/_lib.pyx
+++ b/python/lance/_lib.pyx
@@ -4,13 +4,22 @@ from typing import Optional, Union

 from cython.operator cimport dereference as deref
 from libcpp cimport bool
-from libcpp.memory cimport shared_ptr, const_pointer_cast
+from libcpp.memory cimport const_pointer_cast, shared_ptr
 from libcpp.string cimport string

 from pathlib import Path

 from pyarrow import Table
-from pyarrow._dataset cimport FileFormat, FileWriteOptions, CFileWriteOptions, CScanner, CDataset, Dataset
+
+from pyarrow._dataset cimport (
+    CDataset,
+    CFileWriteOptions,
+    CScanner,
+    Dataset,
+    FileFormat,
+    FileWriteOptions,
+)
+
 from pyarrow._dataset import Scanner

 from pyarrow._compute cimport Expression, _bind
diff --git a/python/lance/tests/test_api.py b/python/lance/tests/test_api.py
index 47aaa65745..1a4cc10db2 100644
--- a/python/lance/tests/test_api.py
+++ b/python/lance/tests/test_api.py
@@ -16,10 +16,10 @@
 from pathlib import Path

 import pandas as pd
-
 import pyarrow as pa
 import pyarrow.dataset as ds
-from lance import write_table, dataset, LanceFileFormat
+
+from lance import LanceFileFormat, dataset, write_table


 def test_simple_round_trips(tmp_path: Path):
@@ -48,7 +48,6 @@ def test_write_categorical_values(tmp_path: Path):
     assert table == actual

-
 def test_write_dataset(tmp_path: Path):
     table = pa.Table.from_pandas(
         pd.DataFrame(
             {
@@ -59,12 +58,7 @@ def test_write_dataset(tmp_path: Path):
             }
         )
     )
-    ds.write_dataset(
-        table,
-        tmp_path,
-        partitioning=["split"],
-        format=LanceFileFormat()
-    )
+    ds.write_dataset(table, tmp_path, partitioning=["split"], format=LanceFileFormat())

     part_dirs = [d.name for d in tmp_path.iterdir()]
     assert set(part_dirs) == set(["a", "b"])
diff --git a/python/lance/tests/test_types.py b/python/lance/tests/test_types.py
index 5c4d59fc26..72d44e43e2 100644
--- a/python/lance/tests/test_types.py
+++ b/python/lance/tests/test_types.py
@@ -14,6 +14,7 @@

 import platform

+import numpy as np
 import pyarrow as pa
 import pytest

@@ -52,6 +53,16 @@ def test_box2d(tmp_path):
     _test_extension_rt(tmp_path, box_type, storage)


+def test_label(tmp_path):
+    label_type = LabelType()
+    values = ["cat", "dog", "horse", "chicken", "donkey", "pig"]
+    indices = np.random.randint(0, len(values), 100)
+    storage = pa.DictionaryArray.from_arrays(
+        pa.array(indices, type=pa.int8()), pa.array(values, type=pa.string())
+    )
+    _test_extension_rt(tmp_path, label_type, storage)
+
+
 def _test_extension_rt(tmp_path, ext_type, storage_arr):
     arr = pa.ExtensionArray.from_storage(ext_type, storage_arr)
     table = pa.Table.from_arrays([arr], names=["ext"])
diff --git a/python/lance/types.py b/python/lance/types.py
index 949124411d..a7f54baf34 100644
--- a/python/lance/types.py
+++ b/python/lance/types.py
@@ -22,7 +22,14 @@
 import pyarrow as pa
 from pyarrow import ArrowKeyError

-__all__ = ["ImageType", "ImageUriType", "ImageBinaryType", "Point2dType", "Box2dType"]
+__all__ = [
+    "ImageType",
+    "ImageUriType",
+    "ImageBinaryType",
+    "Point2dType",
+    "Box2dType",
+    "LabelType",
+]


 class LanceType(pa.ExtensionType, ABC):
@@ -122,6 +129,24 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized):
         return Box2dType()


+class LabelType(LanceType):
+    """
+    A label used for classification. This is backed by a dictionary type
+    to make it easy to translate between human-readable strings and the
+    integer classes used in models.
+    """
+
+    def __init__(self):
+        super(LabelType, self).__init__(pa.dictionary(pa.int8(), pa.string()), "label")
+
+    def __arrow_ext_serialize__(self):
+        return b""
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        return LabelType()
+
+
 def register_extension_types():
     if platform.system() != "Linux":
         raise NotImplementedError("Extension types are only supported on Linux for now")
@@ -130,6 +155,7 @@ def register_extension_types():
         pa.register_extension_type(ImageBinaryType())
         pa.register_extension_type(Point2dType())
         pa.register_extension_type(Box2dType())
+        pa.register_extension_type(LabelType())
     except ArrowKeyError:
         # already registered
         pass
diff --git a/python/setup.py b/python/setup.py
index 6821fc8ed8..0e6bdf02ca 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pathlib import Path
-from setuptools import Extension, find_packages, setup
 import platform
+from pathlib import Path

 import numpy as np
 import pyarrow as pa
 from Cython.Build import cythonize
+from setuptools import Extension, find_packages, setup

 extra_libs = []
 # TODO: cibuildwheel cannot find / doesn't need arrow_python.
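
Usage note appended after the patch (editorial addition, not part of the diff):
a minimal sketch of how the new LabelType is meant to round-trip, mirroring
test_label and the imports used in test_api.py above. It assumes a Linux
environment (register_extension_types raises elsewhere), a locally built
lance, and an illustrative output path.

    import numpy as np
    import pyarrow as pa

    from lance import dataset, write_table
    from lance.types import LabelType, register_extension_types

    # Safe to call even if lance already registered the types: the
    # ArrowKeyError raised on re-registration is swallowed.
    register_extension_types()

    # LabelType storage is dictionary<int8, string>: int8 codes indexing a
    # human-readable vocabulary, as in test_label.
    values = pa.array(["cat", "dog"], type=pa.string())
    indices = pa.array(np.random.randint(0, 2, 10), type=pa.int8())
    storage = pa.DictionaryArray.from_arrays(indices, values)
    labels = pa.ExtensionArray.from_storage(LabelType(), storage)

    table = pa.Table.from_arrays([labels], names=["label"])
    write_table(table, "/tmp/labels.lance")
    round_tripped = dataset("/tmp/labels.lance").to_table()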