diff --git a/cpp/src/lance/arrow/scanner_test.cc b/cpp/src/lance/arrow/scanner_test.cc index 78b2a0aa83..3f6f1dfc6d 100644 --- a/cpp/src/lance/arrow/scanner_test.cc +++ b/cpp/src/lance/arrow/scanner_test.cc @@ -26,8 +26,12 @@ #include "lance/arrow/stl.h" #include "lance/arrow/type.h" +#include "lance/arrow/utils.h" #include "lance/format/schema.h" #include "lance/testing/extension_types.h" +#include "lance/testing/io.h" + +using lance::arrow::ToArray; auto nested_schema = ::arrow::schema({::arrow::field("pk", ::arrow::int32()), ::arrow::field("objects", @@ -202,4 +206,66 @@ TEST_CASE("Test ScanBatchesAsync with batch size") { CHECK(batch.record_batch->num_rows() == kBatchSize); } CHECK(num_batches == kTotalValues / kBatchSize); +} + +// GH-188 +TEST_CASE("Filter over empty list") { + auto values_arr = ToArray({1, 2, 3}).ValueOrDie(); + + auto elem_builder = std::make_shared<::arrow::FloatBuilder>(); + auto list_builder = ::arrow::ListBuilder(::arrow::default_memory_pool(), elem_builder); + CHECK(list_builder.Append().ok()); + CHECK(elem_builder->AppendValues({0.1, 0.2}).ok()); + CHECK(list_builder.AppendNull().ok()); + CHECK(list_builder.Append().ok()); + CHECK(elem_builder->Append(11.1).ok()); + auto list_arr = list_builder.Finish().ValueOrDie(); + + auto schema = ::arrow::schema({::arrow::field("ints", ::arrow::int32()), + ::arrow::field("floats", ::arrow::list(::arrow::float32()))}); + auto t = ::arrow::Table::Make(schema, {values_arr, list_arr}); + + auto dataset = lance::testing::MakeDataset(t).ValueOrDie(); + auto scan_builder = dataset->NewScan().ValueOrDie(); + + // This filter should result in an empty list array + CHECK(scan_builder + ->Filter(::arrow::compute::equal(::arrow::compute::field_ref("ints"), + ::arrow::compute::literal(100))) + .ok()); + auto scanner = scan_builder->Finish().ValueOrDie(); + + auto actual = scanner->ToTable().ValueOrDie(); + CHECK(actual->num_rows() == 0); + CHECK(t->schema()->Equals(actual->schema())); +} + +TEST_CASE("Filter with limit") { + auto values_arr = ToArray({1, 2, 3}).ValueOrDie(); + + auto elem_builder = std::make_shared<::arrow::FloatBuilder>(); + auto list_builder = ::arrow::ListBuilder(::arrow::default_memory_pool(), elem_builder); + CHECK(list_builder.Append().ok()); + CHECK(elem_builder->AppendValues({0.1, 0.2}).ok()); + CHECK(list_builder.AppendNull().ok()); + CHECK(list_builder.Append().ok()); + CHECK(elem_builder->Append(11.1).ok()); + auto list_arr = list_builder.Finish().ValueOrDie(); + + auto schema = ::arrow::schema({::arrow::field("ints", ::arrow::int32()), + ::arrow::field("floats", ::arrow::list(::arrow::float32()))}); + auto t = ::arrow::Table::Make(schema, {values_arr, list_arr}); + auto dataset = lance::testing::MakeDataset(t).ValueOrDie(); + auto scan_builder = lance::arrow::ScannerBuilder(dataset); + CHECK(scan_builder + .Filter(::arrow::compute::equal(::arrow::compute::field_ref("ints"), + ::arrow::compute::literal(100))) + .ok()); + CHECK(scan_builder.Limit(20).ok()); + + auto scanner = scan_builder.Finish().ValueOrDie(); + + auto actual = scanner->ToTable().ValueOrDie(); + CHECK(actual->num_rows() == 0); + CHECK(t->schema()->Equals(actual->schema())); } \ No newline at end of file diff --git a/cpp/src/lance/encodings/plain.cc b/cpp/src/lance/encodings/plain.cc index e53a87d0c7..3aabc4c132 100644 --- a/cpp/src/lance/encodings/plain.cc +++ b/cpp/src/lance/encodings/plain.cc @@ -153,10 +153,15 @@ class PlainDecoderImpl : public Decoder { return Decoder::Take(indices); } + if (indices->length() == 0) { + return MakeEmpty(); + } + int32_t start = indices->Value(0); int32_t length = indices->Value(indices->length() - 1) - start + 1; if (indices->length() == 0 || start < 0 || start + length > length_) { - return ::arrow::Status::Invalid("PlainDecoder::Take: Indices array is not valid"); + return ::arrow::Status::Invalid(fmt::format( + "PlainDecoder::Take: Indices array is not valid: start={}, length={}", start, length)); } // For the simplicity, we read all data in batch to reduce random I/O. // And apply indices later. @@ -187,8 +192,7 @@ class PlainDecoderImpl : public Decoder { } ::arrow::Result> MakeEmpty() const { - ARROW_ASSIGN_OR_RAISE(auto buffer, ::arrow::AllocateBuffer(0)); - return std::make_shared(type_, 0, std::move(buffer)); + return ::arrow::MakeEmptyArray(type_, pool_); } private: @@ -265,7 +269,7 @@ class FixedSizeListPlainDecoderImpl : public Decoder { PlainDecoder::~PlainDecoder() {} ::arrow::Status PlainDecoder::Init() { - assert (!arrow::is_extension(type_)); + assert(!arrow::is_extension(type_)); switch (type_->id()) { case ::arrow::Type::BOOL: impl_.reset(new BooleanPlainDecoderImpl(infile_, type_)); diff --git a/cpp/src/lance/format/schema.cc b/cpp/src/lance/format/schema.cc index 16b7acce58..8330a50e98 100644 --- a/cpp/src/lance/format/schema.cc +++ b/cpp/src/lance/format/schema.cc @@ -477,7 +477,7 @@ ::arrow::Result> Schema::Project( auto actual_field = GetField(components[0]); if (!actual_field) { - return ::arrow::Status::Invalid("Field {} dose not exist.", name); + continue; } auto view_field = view->GetField(components[0]); if (!view_field) { @@ -495,7 +495,7 @@ ::arrow::Result> Schema::Project( for (auto& arrow_field : arrow_schema.fields()) { auto field = GetField(arrow_field->name()); if (!field) { - return ::arrow::Status::Invalid(fmt::format("Field {} dose not exist", arrow_field->name())); + continue; } auto proj_field = field->Project(arrow_field); projection->AddField(proj_field); diff --git a/cpp/src/lance/io/CMakeLists.txt b/cpp/src/lance/io/CMakeLists.txt index 39b952d8d5..24d38d9053 100644 --- a/cpp/src/lance/io/CMakeLists.txt +++ b/cpp/src/lance/io/CMakeLists.txt @@ -36,3 +36,4 @@ target_include_directories(io SYSTEM PRIVATE ${Protobuf_INCLUDE_DIR}) add_dependencies(io format lance_io_exec) add_lance_test(reader_test) +add_lance_test(record_batch_reader_test) diff --git a/cpp/src/lance/io/exec/base.cc b/cpp/src/lance/io/exec/base.cc index 13b8b9d863..2bdba5c155 100644 --- a/cpp/src/lance/io/exec/base.cc +++ b/cpp/src/lance/io/exec/base.cc @@ -14,20 +14,33 @@ #include "lance/io/exec/base.h" +#include + #include + namespace lance::io::exec { ScanBatch ScanBatch::Null() { return ScanBatch(nullptr, -1); } -ScanBatch::ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, int32_t bid) - : batch(records), batch_id(bid) {} +ScanBatch::ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, + int32_t bid, + std::shared_ptr<::arrow::Int32Array> idx) + : batch(std::move(records)), batch_id(bid), indices(std::move(idx)) {} + +ScanBatch ScanBatch::Slice(int64_t offset, int64_t length) const { + auto sliced_batch = batch->Slice(offset, length); + decltype(indices) sliced_indices; + if (indices) { + sliced_indices = std::dynamic_pointer_cast<::arrow::Int32Array>(indices->Slice(offset, length)); + } + return ScanBatch(sliced_batch, batch_id, sliced_indices); +} -ScanBatch ScanBatch::Filtered(std::shared_ptr<::arrow::RecordBatch> records, - int32_t batch_id, - std::shared_ptr<::arrow::Int32Array> indices) { - auto batch = ScanBatch(records, batch_id); - batch.indices = std::move(indices); - return batch; +int64_t ScanBatch::length() const { + if (!batch) { + return 0; + } + return batch->num_rows(); } } // namespace lance::io::exec \ No newline at end of file diff --git a/cpp/src/lance/io/exec/base.h b/cpp/src/lance/io/exec/base.h index be043a6e33..c61a32cb74 100644 --- a/cpp/src/lance/io/exec/base.h +++ b/cpp/src/lance/io/exec/base.h @@ -40,14 +40,6 @@ struct ScanBatch { /// Return a null ScanBatch indicates EOF. static ScanBatch Null(); - /// Constructor with a record batch and batch id. - /// - /// \param records A record batch of values to return - /// \param batch_id the id of the batch - static ScanBatch Filtered(std::shared_ptr<::arrow::RecordBatch> records, - int32_t batch_id, - std::shared_ptr<::arrow::Int32Array> indices); - /// Construct an empty response. ScanBatch() = default; @@ -55,10 +47,19 @@ struct ScanBatch { /// /// \param records A record batch of values to return /// \param batch_id the id of the batch - ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, int32_t batch_id); + /// \param indices the indices from filter. Optional + ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, + int32_t batch_id, + std::shared_ptr<::arrow::Int32Array> indices = nullptr); /// Returns True if the end of file is reached. bool eof() const { return !batch; } + + /// Make a zero-copy slice from this batch. + ScanBatch Slice(int64_t offset, int64_t length) const; + + /// The length of this batch. + int64_t length() const; }; /// I/O execute base node. diff --git a/cpp/src/lance/io/exec/filter.cc b/cpp/src/lance/io/exec/filter.cc index f6af676ba5..c21fe7b6e6 100644 --- a/cpp/src/lance/io/exec/filter.cc +++ b/cpp/src/lance/io/exec/filter.cc @@ -40,13 +40,13 @@ ::arrow::Result Filter::Next() { if (batch.eof()) { return ScanBatch::Null(); } - if (batch.batch->num_rows() == 0) { + if (batch.length() == 0) { return batch; } ARROW_ASSIGN_OR_RAISE(auto indices_and_values, Apply(*batch.batch)); auto [indices, values] = indices_and_values; ARROW_ASSIGN_OR_RAISE(auto values_arr, values->ToStructArray()); - return ScanBatch::Filtered(values, batch.batch_id, indices); + return ScanBatch(values, batch.batch_id, indices); } ::arrow::Result< diff --git a/cpp/src/lance/io/exec/limit.cc b/cpp/src/lance/io/exec/limit.cc index 1467020463..cce5c6826c 100644 --- a/cpp/src/lance/io/exec/limit.cc +++ b/cpp/src/lance/io/exec/limit.cc @@ -14,6 +14,7 @@ #include "lance/io/exec/limit.h" +#include #include #include @@ -51,18 +52,18 @@ ::arrow::Result Limit::Next() { return batch; } // Find intersection of two ranges (offset, offset + limit) and (seen, seen + batch_size). - auto batch_size = batch.batch->num_rows(); + auto batch_size = batch.length(); auto left = std::max(offset_, seen_); auto right = std::min(seen_ + batch_size, offset_ + limit_); - std::shared_ptr<::arrow::RecordBatch> record_batch; + ScanBatch limited_batch; if (left < right) { - record_batch = batch.batch->Slice(left - seen_, right - left); + limited_batch = batch.Slice(left - seen_, right - left); } else { /// No intersection, skip the whole batch. - ARROW_ASSIGN_OR_RAISE(record_batch, ::arrow::RecordBatch::MakeEmpty(batch.batch->schema())); + limited_batch = batch.Slice(0, 0); } - seen_ += batch_size; - return ScanBatch{record_batch, batch.batch_id}; + seen_ += batch.length(); + return limited_batch; } std::string Limit::ToString() const { diff --git a/cpp/src/lance/io/reader.cc b/cpp/src/lance/io/reader.cc index d36376e214..7f3f0ed02d 100644 --- a/cpp/src/lance/io/reader.cc +++ b/cpp/src/lance/io/reader.cc @@ -15,6 +15,7 @@ #include "lance/io/reader.h" #include +#include #include #include #include @@ -27,6 +28,7 @@ #include #include "lance/arrow/type.h" +#include "lance/arrow/utils.h" #include "lance/encodings/binary.h" #include "lance/encodings/plain.h" #include "lance/format/format.h" @@ -353,8 +355,7 @@ ::arrow::Result> FileReader::GetListArray( // TODO: GH-39. We should improve the read behavior to use indices to save some I/Os. auto& indices = params.indices.value(); if (indices->length() == 0) { - return ::arrow::Status::IndexError(fmt::format( - "FileReader::GetListArray: indices is empty: field={}({})", field->name(), field->id())); + return ::arrow::MakeEmptyArray(field->type()); } auto start = static_cast(indices->Value(0)); auto length = static_cast(indices->Value(indices->length() - 1) - start + 1); diff --git a/cpp/src/lance/io/reader_test.cc b/cpp/src/lance/io/reader_test.cc index 03349bfae2..0fc7c50fb0 100644 --- a/cpp/src/lance/io/reader_test.cc +++ b/cpp/src/lance/io/reader_test.cc @@ -15,6 +15,7 @@ #include "lance/io/reader.h" #include +#include #include #include #include @@ -24,6 +25,7 @@ #include "lance/arrow/stl.h" #include "lance/arrow/writer.h" +#include "lance/testing/io.h" TEST_CASE("Test List Array With Nulls") { auto int_builder = std::make_shared<::arrow::Int32Builder>(); @@ -97,4 +99,30 @@ TEST_CASE("Get List Array With Indices") { .ValueOrDie(); CHECK(batch->Equals(*expected_table->CombineChunksToBatch().ValueOrDie())); } +} + +TEST_CASE("Filter over dictionary base") { + auto indices = lance::arrow::ToArray({0, 1, 1, 2}).ValueOrDie(); + auto dict = lance::arrow::ToArray({"car", "horse", "plane", "bike", "cat"}).ValueOrDie(); + auto dict_arr = + ::arrow::DictionaryArray::FromArrays(::arrow::dictionary(::arrow::int8(), ::arrow::utf8()), + std::static_pointer_cast<::arrow::Array>(indices), + std::static_pointer_cast<::arrow::Array>(dict)) + .ValueOrDie(); + auto value_arr = lance::arrow::ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + + auto schema = ::arrow::schema( + {::arrow::field("category", ::arrow::dictionary(::arrow::int8(), ::arrow::utf8())), + ::arrow::field("value", ::arrow::int32())}); + auto tab = ::arrow::Table::Make(schema, {dict_arr, value_arr}); + + auto dataset = lance::testing::MakeDataset(tab).ValueOrDie(); + auto scan_builder = dataset->NewScan().ValueOrDie(); + CHECK(scan_builder + ->Filter(::arrow::compute::equal(::arrow::compute::field_ref("category"), + ::arrow::compute::literal("bike"))) + .ok()); + auto scanner = scan_builder->Finish().ValueOrDie(); + auto actual = scanner->ToTable().ValueOrDie(); + CHECK(actual->num_rows() == 0); } \ No newline at end of file diff --git a/cpp/src/lance/io/record_batch_reader_test.cc b/cpp/src/lance/io/record_batch_reader_test.cc new file mode 100644 index 0000000000..62615f5735 --- /dev/null +++ b/cpp/src/lance/io/record_batch_reader_test.cc @@ -0,0 +1,56 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include + +#include "lance/arrow/stl.h" +#include "lance/testing/io.h" + +using lance::arrow::ToArray; + +TEST_CASE("Scan partitioned dataset") { + auto value_arr = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto split_arr = ToArray({"train", "train", "eval", "test", "train"}).ValueOrDie(); + + auto schema = ::arrow::schema( + {::arrow::field("value", ::arrow::int32()), ::arrow::field("split", ::arrow::utf8())}); + auto t = ::arrow::Table::Make(schema, {value_arr, split_arr}); + + auto dataset = lance::testing::MakeDataset(t, {"split"}).ValueOrDie(); + auto scanner = dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(); + auto actual = scanner->ToTable().ValueOrDie()->CombineChunks().ValueOrDie(); + auto indices = ::arrow::compute::SortIndices(*actual->GetColumnByName("value")).ValueOrDie(); + auto new_datum = ::arrow::compute::Take(actual, indices).ValueOrDie(); + auto sorted_table = new_datum.table(); + INFO("Expected table: " << t->ToString() << " \nActual table: " << sorted_table->ToString()); + CHECK(t->Equals(*sorted_table)); +} + +TEST_CASE("Scan partitioned dataset with nonexistent column") { + auto value_arr = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto split_arr = ToArray({"train", "train", "eval", "test", "train"}).ValueOrDie(); + + auto schema = ::arrow::schema( + {::arrow::field("value", ::arrow::int32()), ::arrow::field("split", ::arrow::utf8())}); + auto t = ::arrow::Table::Make(schema, {value_arr, split_arr}); + auto dataset = lance::testing::MakeDataset(t, {"split"}).ValueOrDie(); + auto scan_builder = dataset->NewScan().ValueOrDie(); + // Woo column does not exist in the dataset, split column does not exist in the lance file. + CHECK(!scan_builder->Project({"value", "split", "woo"}).ok()); +} diff --git a/cpp/src/lance/testing/io.cc b/cpp/src/lance/testing/io.cc index c83dd6a6d2..66c357479d 100644 --- a/cpp/src/lance/testing/io.cc +++ b/cpp/src/lance/testing/io.cc @@ -14,14 +14,34 @@ #include "lance/testing/io.h" +#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include + +#include "lance/arrow/file_lance.h" #include "lance/arrow/writer.h" #include "lance/io/reader.h" namespace lance::testing { +::arrow::Result MakeTemporaryDir() { + std::string temp = (std::filesystem::temp_directory_path() / "lance-test-XXXXXX"); + auto temp_dir = mkdtemp(temp.data()); + if (temp_dir == nullptr) { + return ::arrow::Status::IOError(strerror(errno)); + } + return std::string(temp_dir); +} + ::arrow::Result> MakeReader( const std::shared_ptr<::arrow::Table>& table) { auto sink = ::arrow::io::BufferOutputStream::Create().ValueOrDie(); @@ -32,6 +52,46 @@ ::arrow::Result> MakeReader( return reader; } +::arrow::Result> MakeDataset( + const std::shared_ptr<::arrow::Table>& table, const std::vector& partitions) { + auto sink = ::arrow::io::BufferOutputStream::Create().ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); + ARROW_ASSIGN_OR_RAISE(auto scanner_builder, dataset->NewScan()); + ARROW_ASSIGN_OR_RAISE(auto scanner, scanner_builder->Finish()); + + auto format = lance::arrow::LanceFileFormat::Make(); + + ::arrow::dataset::FileSystemDatasetWriteOptions write_options; + + auto tmp_dir = "file://" + MakeTemporaryDir().ValueOrDie(); + std::string path; + std::vector> partition_fields; + for (auto& part_col : partitions) { + partition_fields.emplace_back(table->schema()->GetFieldByName(part_col)); + } + auto partition_schema = ::arrow::schema(partition_fields); + auto fs = ::arrow::fs::FileSystemFromUri(tmp_dir, &path).ValueOrDie(); + write_options.file_write_options = format->DefaultWriteOptions(); + write_options.filesystem = fs; + write_options.base_dir = path; + write_options.partitioning = + std::make_shared<::arrow::dataset::HivePartitioning>(partition_schema); + write_options.basename_template = "part{i}.lance"; + + ARROW_RETURN_NOT_OK(::arrow::dataset::FileSystemDataset::Write(write_options, scanner)); + + // Read the dataset back + ::arrow::fs::FileSelector selector; + selector.base_dir = write_options.base_dir; + selector.recursive = true; + ::arrow::dataset::FileSystemFactoryOptions factory_options; + factory_options.partitioning = write_options.partitioning; + ARROW_ASSIGN_OR_RAISE(auto factory, + ::arrow::dataset::FileSystemDatasetFactory::Make( + fs, selector, format, factory_options)); + return factory->Finish(); +} + TableScan::TableScan(const ::arrow::Table& table, int64_t batch_size) : reader_(new ::arrow::TableBatchReader(table)) { reader_->set_chunksize(batch_size); @@ -45,7 +105,7 @@ ::arrow::Result TableScan::Next() { if (!reader_) { return io::exec::ScanBatch{}; } - + std::shared_ptr<::arrow::RecordBatch> batch; ARROW_RETURN_NOT_OK(reader_->ReadNext(&batch)); return io::exec::ScanBatch{ diff --git a/cpp/src/lance/testing/io.h b/cpp/src/lance/testing/io.h index 3e703eeb6d..df73dab044 100644 --- a/cpp/src/lance/testing/io.h +++ b/cpp/src/lance/testing/io.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -25,10 +26,22 @@ namespace lance::testing { +/// Make temporary directory and returns the directory path. +::arrow::Result MakeTemporaryDir(); + /// Make lance::io::FileReader from an Arrow Table. ::arrow::Result> MakeReader( const std::shared_ptr<::arrow::Table>& table); +/// Make a FileSystem Dataset from the table. +/// +/// \param table The table to write +/// \param partitions the column names of partitioning. +/// \return a FileSystem Dataset with lance format. +::arrow::Result> MakeDataset( + const std::shared_ptr<::arrow::Table>& table, + const std::vector& partitions = {}); + /// A ExecNode that scans a Table in memory. /// /// This node can be used in test without creating files. diff --git a/python/benchmarks/bench_utils.py b/python/benchmarks/bench_utils.py index 739a797fa2..22073c1c73 100644 --- a/python/benchmarks/bench_utils.py +++ b/python/benchmarks/bench_utils.py @@ -267,3 +267,68 @@ def make_embedded_dataset( @abstractmethod def get_schema(self): pass + + @classmethod + def create_main(cls): + FORMATS = click.Choice(["lance", "parquet"]) + + @click.command() + @click.argument("base_uri") + @click.option( + "-f", + "--fmt", + type=FORMATS, + default="lance", + help="Output format (parquet or lance)", + ) + @click.option("-e", "--embedded", type=bool, default=True, help="Embed images") + @click.option( + "-g", + "--group-size", + type=int, + default=1024, + help="group size", + show_default=True, + ) + @click.option( + "--max-rows-per-file", + type=int, + default=0, + help="max rows per file", + show_default=True, + ) + @click.option( + "-o", + "--output-path", + type=str, + help="Output path. Default is under the base_uri", + ) + def main( + base_uri, + fmt, + embedded, + output_path, + group_size: int, + max_rows_per_file: int, + ): + converter = cls(base_uri) + df = converter.read_metadata() + known_formats = ["lance", "parquet"] + if fmt is not None: + assert fmt in known_formats + fmt = [fmt] + else: + fmt = known_formats + + kwargs = { + "existing_data_behavior": "overwrite_or_ignore", + "max_rows_per_group": group_size, + "max_rows_per_file": max_rows_per_file, + } + for f in fmt: + if embedded: + converter.make_embedded_dataset(df, f, output_path, **kwargs) + else: + return converter.save_df(df, f, output_path, **kwargs) + + return main diff --git a/python/benchmarks/coco/datagen.py b/python/benchmarks/coco/datagen.py index 38a86dbd21..a136740175 100755 --- a/python/benchmarks/coco/datagen.py +++ b/python/benchmarks/coco/datagen.py @@ -9,13 +9,10 @@ sys.path.append("..") -import click import pandas as pd import pyarrow as pa from bench_utils import DatasetConverter -import lance -import lance.types from lance.types import ImageType @@ -232,71 +229,6 @@ def _aggregate_annotations(annotations): return ret -@click.command() -@click.argument("base_uri") -@click.option( - "-v", "--version", type=str, default="2017", help="Dataset version. Default 2017" -) -@click.option( - "-f", - "--fmt", - type=click.Choice(["lance", "parquet"]), - default="lance", - help="Output format (parquet or lance)", -) -@click.option("-e", "--embedded", type=bool, default=True, help="Embed images") -@click.option( - "-g", - "--group-size", - type=int, - default=1024, - help="set the group size", - show_default=True, -) -@click.option( - "--max-rows-per-file", - type=int, - default=0, - help="set the max rows per file", - show_default=True, -) -@click.option( - "-o", - "--output-path", - type=str, - help="Output path. Default is {base_uri}/coco_links.{fmt}", -) -def main( - base_uri, - version, - fmt, - embedded, - output_path, - group_size: int, - max_rows_per_file: int, -): - converter = CocoConverter(base_uri, version=version) - df = converter.read_metadata() - known_formats = ["lance", "parquet"] - if fmt is not None: - assert fmt in known_formats - fmt = [fmt] - else: - fmt = known_formats - - kwargs = { - "existing_data_behavior": "overwrite_or_ignore", - "partitioning": ["split"], - "partitioning_flavor": "hive", - "max_rows_per_group": group_size, - "max_rows_per_file": max_rows_per_file, - } - for f in fmt: - if embedded: - converter.make_embedded_dataset(df, f, output_path, **kwargs) - else: - return converter.save_df(df, f, output_path, **kwargs) - - if __name__ == "__main__": + main = CocoConverter.create_main() main() diff --git a/python/benchmarks/oxford_pet/datagen.py b/python/benchmarks/oxford_pet/datagen.py index 01fb664d13..c55de12f27 100755 --- a/python/benchmarks/oxford_pet/datagen.py +++ b/python/benchmarks/oxford_pet/datagen.py @@ -20,7 +20,6 @@ sys.path.append("..") -import click import numpy as np import pandas as pd import pyarrow as pa @@ -211,49 +210,6 @@ def _get_xml(uri: str): return {} -@click.command -@click.option( - "-u", "--base-uri", type=str, required=True, help="Oxford Pet dataset root" -) -@click.option( - "-f", - "--fmt", - type=click.Choice(["parquet", "lance"]), - help="Output format (parquet or lance)", -) -@click.option("-e", "--embedded", type=bool, default=True, help="store embedded images") -@click.option( - "-g", "--group-size", type=int, default=1024, help="set max_rows_per_group in arrow" -) -@click.option( - "--max-rows-per-file", type=int, default=0, help="set max_rows_per_file in arrow" -) -@click.option( - "-o", "--output", type=str, default="oxford_pet.lance", help="Output path" -) -def main(base_uri, fmt, embedded, output, group_size, max_rows_per_file): - known_formats = ["lance", "parquet"] - if fmt is not None: - assert fmt in known_formats - fmt = [fmt] - else: - fmt = known_formats - converter = OxfordPetConverter(base_uri) - df = converter.read_metadata() - for f in fmt: - if embedded: - converter.make_embedded_dataset( - df, - f, - output_path=output, - partitioning=["split"], - existing_data_behavior="overwrite_or_ignore", - max_rows_per_group=group_size, - max_rows_per_file=max_rows_per_file, # Create enough files for parallelism - ) - else: - converter.save_df(df, f, output_path=output, partitioning=["split"]) - - if __name__ == "__main__": + main = OxfordPetConverter.create_main() main() diff --git a/python/lance/pytorch/__init__.py b/python/lance/pytorch/__init__.py index db8ddda5cb..943e243bd4 100644 --- a/python/lance/pytorch/__init__.py +++ b/python/lance/pytorch/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. """PyTorch Integration -""" \ No newline at end of file +""" diff --git a/python/lance/pytorch/data.py b/python/lance/pytorch/data.py index 25f5b69ca1..acef87e507 100644 --- a/python/lance/pytorch/data.py +++ b/python/lance/pytorch/data.py @@ -43,6 +43,8 @@ def _data_to_tensor(data: Any) -> Union[torch.Tensor, PIL.Image.Image]: return data.to_pil() elif isinstance(data, dict): return {k: to_tensor(v) for k, v in data.items()} + elif isinstance(data, str): + return data else: return torch.tensor(data) diff --git a/python/lance/tests/test_pytorch.py b/python/lance/tests/test_pytorch.py index 82ae410288..7ba4f94c80 100644 --- a/python/lance/tests/test_pytorch.py +++ b/python/lance/tests/test_pytorch.py @@ -22,6 +22,7 @@ import pandas as pd import PIL import pyarrow as pa +import pyarrow.compute as pc import lance from lance.pytorch.data import LanceDataset @@ -68,3 +69,21 @@ def test_dataset_with_ext_types(tmp_path: Path): images, labels = batch assert all([isinstance(p, PIL.Image.Image) for p in images]) assert torch.equal(labels, torch.tensor([0, 1, 2, 0], dtype=torch.int8)) + + +def test_data_loader_with_filter(tmp_path: Path): + torch.Tensor([1, 2, 3]) + ids = pa.array(range(10)) + values = pa.array(range(10, 20)) + split = pa.array(["train", "val"] * 5) + tab = pa.Table.from_arrays([ids, values, split], names=["id", "value", "split"]) + + lance.write_table(tab, tmp_path / "lance") + + dataset = LanceDataset(tmp_path / "lance", filter=pc.field("split") == "train") + for id, value, split in dataset: + assert split == "train" + assert id % 2 == 0 + assert torch.is_tensor(id) + assert (value - 10) % 2 == 0 + assert torch.is_tensor(value) diff --git a/python/lance/tests/test_types.py b/python/lance/tests/test_types.py index a2a42fa49c..c0136545ef 100644 --- a/python/lance/tests/test_types.py +++ b/python/lance/tests/test_types.py @@ -218,3 +218,10 @@ def test_pickle(tmp_path): pickle.dump(img, fh) with (tmp_path / "image").open("rb") as fh: assert img == pickle.load(fh) + + img = Image.create(bytearray(b"bytes")) + assert isinstance(img, ImageBinary) + with (tmp_path / "image").open("wb") as fh: + pickle.dump(img, fh) + with (tmp_path / "image").open("rb") as fh: + assert img == pickle.load(fh) diff --git a/python/lance/types/image.py b/python/lance/types/image.py index e80bd86def..3726588a21 100644 --- a/python/lance/types/image.py +++ b/python/lance/types/image.py @@ -97,9 +97,12 @@ class Image(ABC): """ @staticmethod - def create(data: Optional[Union[bytes, str]]): + def create(data: Optional[Union[bytes, bytearray, str]]): if pd.isna(data): return None + if isinstance(data, bytearray): + data = bytes(data) + if isinstance(data, bytes): img = ImageBinary(data) elif isinstance(data, str):