From 4c049545ba5f4b1c8f742894de3e1ad8736616a2 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 1 Sep 2022 13:36:15 -0700 Subject: [PATCH 1/2] set BatchSize --- cpp/include/lance/arrow/scanner.h | 4 ++ cpp/src/lance/arrow/scanner.cc | 6 +++ cpp/src/lance/arrow/scanner_test.cc | 74 ++++++++++++++++++++++++++--- 3 files changed, 77 insertions(+), 7 deletions(-) diff --git a/cpp/include/lance/arrow/scanner.h b/cpp/include/lance/arrow/scanner.h index a0656a0351..6e3a46fd0c 100644 --- a/cpp/include/lance/arrow/scanner.h +++ b/cpp/include/lance/arrow/scanner.h @@ -46,6 +46,9 @@ class ScannerBuilder final { /// \param columns Selected column names. void Project(const std::vector& columns); + /// Set batch size to scan. + void BatchSize(int64_t batch_size); + /// Apply Filter void Filter(const ::arrow::compute::Expression& filter); @@ -58,6 +61,7 @@ class ScannerBuilder final { std::shared_ptr<::arrow::dataset::Dataset> dataset_; std::optional> columns_ = std::nullopt; ::arrow::compute::Expression filter_ = ::arrow::compute::literal(true); + std::optional batch_size_ = std::nullopt; std::optional limit_ = std::nullopt; int64_t offset_ = 0; }; diff --git a/cpp/src/lance/arrow/scanner.cc b/cpp/src/lance/arrow/scanner.cc index a8426b962f..61d30a3ee9 100644 --- a/cpp/src/lance/arrow/scanner.cc +++ b/cpp/src/lance/arrow/scanner.cc @@ -33,6 +33,8 @@ void ScannerBuilder::Project(const std::vector& columns) { columns_ void ScannerBuilder::Filter(const ::arrow::compute::Expression& filter) { filter_ = filter; } +void ScannerBuilder::BatchSize(int64_t batch_size) { batch_size_ = batch_size; } + void ScannerBuilder::Limit(int64_t limit, int64_t offset) { limit_ = limit; offset_ = offset; @@ -45,6 +47,10 @@ ::arrow::Result> ScannerBuilder::Fini auto builder = ::arrow::dataset::ScannerBuilder(dataset_); ARROW_RETURN_NOT_OK(builder.Filter(filter_)); + if (batch_size_) { + ARROW_RETURN_NOT_OK(builder.BatchSize(batch_size_.value())); + } + auto fragment_opts = std::make_shared(); fragment_opts->limit = limit_; fragment_opts->offset = offset_; diff --git a/cpp/src/lance/arrow/scanner_test.cc b/cpp/src/lance/arrow/scanner_test.cc index a9a45a3ee2..ba1bff998b 100644 --- a/cpp/src/lance/arrow/scanner_test.cc +++ b/cpp/src/lance/arrow/scanner_test.cc @@ -20,16 +20,15 @@ #include #include #include -#include #include #include -#include "lance/arrow/type.h" +#include "lance/arrow/stl.h" #include "lance/arrow/testing.h" +#include "lance/arrow/type.h" #include "lance/format/schema.h" - auto nested_schema = ::arrow::schema({::arrow::field("pk", ::arrow::int32()), ::arrow::field("objects", ::arrow::list(::arrow::struct_({ @@ -83,7 +82,6 @@ TEST_CASE("Build Scanner with nested struct") { fmt::print("Scanner Options: {}\n", scanner->options()->filter.ToString()); } - std::shared_ptr<::arrow::Table> MakeTable() { auto ext_type = std::make_shared<::lance::testing::ParametricType>(1); ::arrow::StringBuilder stringBuilder; @@ -97,8 +95,8 @@ std::shared_ptr<::arrow::Table> MakeTable() { auto c2 = intBuilder.Finish().ValueOrDie(); intBuilder.Reset(); - auto schema = ::arrow::schema({arrow::field("c1", ::arrow::utf8()), - arrow::field("c2", ext_type)}); + auto schema = + ::arrow::schema({arrow::field("c1", ::arrow::utf8()), arrow::field("c2", ext_type)}); std::vector> cols; cols.push_back(c1); @@ -119,7 +117,6 @@ std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow:: return scanner; } - TEST_CASE("Scanner with extension") { auto table = MakeTable(); auto ext_type = std::make_shared<::lance::testing::ParametricType>(1); @@ -140,4 +137,67 @@ TEST_CASE("Scanner with extension") { auto actual_table = scanner->ToTable().ValueOrDie(); CHECK(actual_table->schema()->Equals(expected_proj_schema)); CHECK(actual_table->GetColumnByName("c2")->type()->Equals(ext_type)); +} + +::arrow::Result> MakeScannerForBatchScan( + int64_t num_values, int64_t batch_size) { + std::vector values(num_values); + std::iota(values.begin(), values.end(), 0); + auto arr = lance::arrow::ToArray(values).ValueOrDie(); + auto table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("value", ::arrow::int32())}), {arr}); + + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); + auto scanner_builder = lance::arrow::ScannerBuilder(dataset); + scanner_builder.BatchSize(batch_size); + return scanner_builder.Finish(); +} + +TEST_CASE("Test Scanner::ToRecordBatchReader with batch size") { + const int kTotalValues = 100; + const int kBatchSize = 4; + auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie(); + auto record_batch_reader = scanner->ToRecordBatchReader().ValueOrDie(); + int num_batches = 0; + while (auto batch = record_batch_reader->Next().ValueOrDie()) { + CHECK(batch->num_rows() == kBatchSize); + num_batches++; + } + CHECK(num_batches == kTotalValues / kBatchSize); +} + +TEST_CASE("Test Scanner::ScanBatch with batch size") { + const int kTotalValues = 100; + const int kBatchSize = 4; + auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie(); + auto batches = scanner->ScanBatches().ValueOrDie(); + int num_batches = 0; + while (true) { + auto batch = batches.Next().ValueOrDie(); + if (!batch.record_batch) { + break; + } + CHECK(batch.record_batch->num_rows() == kBatchSize); + num_batches++; + } + CHECK(num_batches == kTotalValues / kBatchSize); +} + +TEST_CASE("Test ScanBatchesAsync with batch size") { + const int kTotalValues = 100; + const int kBatchSize = 4; + auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie(); + auto generator = scanner->ScanBatchesAsync().ValueOrDie(); + int num_batches = 0; + while (true) { + auto fut = generator(); + CHECK(fut.Wait(1)); + auto batch = fut.result().ValueOrDie(); + if (!batch.record_batch) { + break; + } + num_batches++; + CHECK(batch.record_batch->num_rows() == kBatchSize); + } + CHECK(num_batches == kTotalValues / kBatchSize); } \ No newline at end of file From c738ffa145485348fae4700fb0c8d1c226d664d1 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 1 Sep 2022 14:17:59 -0700 Subject: [PATCH 2/2] make ScannerBuilder setters return status --- cpp/include/lance/arrow/scanner.h | 19 +++++----- cpp/src/lance/arrow/scanner.cc | 54 +++++++++++++++-------------- cpp/src/lance/arrow/scanner_test.cc | 16 +++++---- cpp/src/lance/io/project_test.cc | 8 +++-- 4 files changed, 50 insertions(+), 47 deletions(-) diff --git a/cpp/include/lance/arrow/scanner.h b/cpp/include/lance/arrow/scanner.h index 6e3a46fd0c..1968e3a47f 100644 --- a/cpp/include/lance/arrow/scanner.h +++ b/cpp/include/lance/arrow/scanner.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -44,26 +45,22 @@ class ScannerBuilder final { /// Project over selected columns. /// /// \param columns Selected column names. - void Project(const std::vector& columns); + ::arrow::Status Project(const std::vector& columns); /// Set batch size to scan. - void BatchSize(int64_t batch_size); + ::arrow::Status BatchSize(int64_t batch_size); /// Apply Filter - void Filter(const ::arrow::compute::Expression& filter); + ::arrow::Status Filter(const ::arrow::compute::Expression& filter); - /// Set limit to the dataset - void Limit(int64_t limit, int64_t offset = 0); + /// Set limit and offset to scan. + ::arrow::Status Limit(int64_t limit, int64_t offset = 0); - ::arrow::Result> Finish() const; + ::arrow::Result> Finish(); private: - std::shared_ptr<::arrow::dataset::Dataset> dataset_; + ::arrow::dataset::ScannerBuilder builder_; std::optional> columns_ = std::nullopt; - ::arrow::compute::Expression filter_ = ::arrow::compute::literal(true); - std::optional batch_size_ = std::nullopt; - std::optional limit_ = std::nullopt; - int64_t offset_ = 0; }; } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/scanner.cc b/cpp/src/lance/arrow/scanner.cc index 61d30a3ee9..7dcd3e598c 100644 --- a/cpp/src/lance/arrow/scanner.cc +++ b/cpp/src/lance/arrow/scanner.cc @@ -27,36 +27,34 @@ namespace lance::arrow { ScannerBuilder::ScannerBuilder(std::shared_ptr<::arrow::dataset::Dataset> dataset) - : dataset_(dataset) {} + : builder_(dataset) {} -void ScannerBuilder::Project(const std::vector& columns) { columns_ = columns; } - -void ScannerBuilder::Filter(const ::arrow::compute::Expression& filter) { filter_ = filter; } - -void ScannerBuilder::BatchSize(int64_t batch_size) { batch_size_ = batch_size; } +::arrow::Status ScannerBuilder::Project(const std::vector& columns) { + columns_ = columns; + return ::arrow::Status::OK(); +} -void ScannerBuilder::Limit(int64_t limit, int64_t offset) { - limit_ = limit; - offset_ = offset; +::arrow::Status ScannerBuilder::Filter(const ::arrow::compute::Expression& filter) { + return builder_.Filter(filter); } -::arrow::Result> ScannerBuilder::Finish() const { - if (offset_ < 0) { - return ::arrow::Status::Invalid("Offset is negative"); - } - auto builder = ::arrow::dataset::ScannerBuilder(dataset_); - ARROW_RETURN_NOT_OK(builder.Filter(filter_)); +::arrow::Status ScannerBuilder::BatchSize(int64_t batch_size) { + return builder_.BatchSize(batch_size); +} - if (batch_size_) { - ARROW_RETURN_NOT_OK(builder.BatchSize(batch_size_.value())); +::arrow::Status ScannerBuilder::Limit(int64_t limit, int64_t offset) { + if (limit <= 0 || offset < 0) { + return ::arrow::Status::Invalid("Limit / offset is invalid: limit=", limit, " offset=", offset); } - auto fragment_opts = std::make_shared(); - fragment_opts->limit = limit_; - fragment_opts->offset = offset_; - ARROW_RETURN_NOT_OK(builder.FragmentScanOptions(fragment_opts)); + fragment_opts->limit = limit; + fragment_opts->offset = offset; + + return builder_.FragmentScanOptions(fragment_opts); +} - ARROW_ASSIGN_OR_RAISE(auto scanner, builder.Finish()); +::arrow::Result> ScannerBuilder::Finish() { + ARROW_ASSIGN_OR_RAISE(auto scanner, builder_.Finish()); /// We do the schema projection manually here to support nested structs. /// Esp. for `list`, supports Spark-like access, for example, @@ -103,10 +101,14 @@ ::arrow::Result> ScannerBuilder::Fini scanner->options()->projection = project_desc.expression; } - if (limit_.has_value()) { - scanner->options()->batch_size = offset_ + limit_.value(); - /// We need to limit the parallelism for Project to calculate LIMIT / Offset - scanner->options()->batch_readahead = 1; + if (scanner->options()->fragment_scan_options) { + auto fso = std::dynamic_pointer_cast( + scanner->options()->fragment_scan_options); + if (fso->limit) { + scanner->options()->batch_size = fso->limit.value(); + /// We need to limit the parallelism for Project to calculate LIMIT / Offset + scanner->options()->batch_readahead = 1; + } } return scanner; diff --git a/cpp/src/lance/arrow/scanner_test.cc b/cpp/src/lance/arrow/scanner_test.cc index ba1bff998b..fc9ca74b8e 100644 --- a/cpp/src/lance/arrow/scanner_test.cc +++ b/cpp/src/lance/arrow/scanner_test.cc @@ -61,10 +61,12 @@ TEST_CASE("Build Scanner with nested struct") { auto table = ::arrow::Table::MakeEmpty(nested_schema).ValueOrDie(); auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto scanner_builder = lance::arrow::ScannerBuilder(dataset); - scanner_builder.Limit(10); - scanner_builder.Project({"objects.val"}); - scanner_builder.Filter(::arrow::compute::equal(::arrow::compute::field_ref({"objects", 0, "val"}), - ::arrow::compute::literal(2))); + CHECK(scanner_builder.Limit(10).ok()); + CHECK(scanner_builder.Project({"objects.val"}).ok()); + CHECK(scanner_builder + .Filter(::arrow::compute::equal(::arrow::compute::field_ref({"objects", 0, "val"}), + ::arrow::compute::literal(2))) + .ok()); auto result = scanner_builder.Finish(); CHECK(result.ok()); auto scanner = result.ValueOrDie(); @@ -107,8 +109,8 @@ std::shared_ptr<::arrow::Table> MakeTable() { std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow::Table> table) { auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto scanner_builder = lance::arrow::ScannerBuilder(dataset); - scanner_builder.Limit(2); - scanner_builder.Project({"c2"}); + CHECK(scanner_builder.Limit(2).ok()); + CHECK(scanner_builder.Project({"c2"}).ok()); // TODO how can extension types implement comparisons for filtering against storage type? auto result = scanner_builder.Finish(); CHECK(result.ok()); @@ -149,7 +151,7 @@ ::arrow::Result> MakeScannerForBatchS auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto scanner_builder = lance::arrow::ScannerBuilder(dataset); - scanner_builder.BatchSize(batch_size); + CHECK(scanner_builder.BatchSize(batch_size).ok()); return scanner_builder.Finish(); } diff --git a/cpp/src/lance/io/project_test.cc b/cpp/src/lance/io/project_test.cc index bdede8f183..ef4cd538a1 100644 --- a/cpp/src/lance/io/project_test.cc +++ b/cpp/src/lance/io/project_test.cc @@ -46,9 +46,11 @@ TEST_CASE("Project schema") { auto dataset = std::make_shared(tbl); auto scan_builder = lance::arrow::ScannerBuilder(dataset); - scan_builder.Project({"v"}); - scan_builder.Filter( - arrow::compute::equal(arrow::compute::field_ref("v"), arrow::compute::literal(20))); + CHECK(scan_builder.Project({"v"}).ok()); + CHECK(scan_builder + .Filter( + arrow::compute::equal(arrow::compute::field_ref("v"), arrow::compute::literal(20))) + .ok()); auto scanner = scan_builder.Finish().ValueOrDie(); auto lance_schema = lance::format::Schema(schema);