Setting BatchSize via ScanBuilder #135

Merged 2 commits on Sep 1, 2022
19 changes: 10 additions & 9 deletions cpp/include/lance/arrow/scanner.h
@@ -15,6 +15,7 @@
#pragma once

#include <arrow/compute/exec/expression.h>
#include <arrow/dataset/scanner.h>
#include <arrow/dataset/type_fwd.h>
#include <arrow/result.h>

@@ -44,22 +45,22 @@ class ScannerBuilder final {
/// Project over selected columns.
///
/// \param columns Selected column names.
void Project(const std::vector<std::string>& columns);
::arrow::Status Project(const std::vector<std::string>& columns);

/// Set batch size to scan.
::arrow::Status BatchSize(int64_t batch_size);

/// Apply Filter
void Filter(const ::arrow::compute::Expression& filter);
::arrow::Status Filter(const ::arrow::compute::Expression& filter);

/// Set limit to the dataset
void Limit(int64_t limit, int64_t offset = 0);
/// Set limit and offset to scan.
::arrow::Status Limit(int64_t limit, int64_t offset = 0);

::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> Finish() const;
::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> Finish();

private:
std::shared_ptr<::arrow::dataset::Dataset> dataset_;
::arrow::dataset::ScannerBuilder builder_;
std::optional<std::vector<std::string>> columns_ = std::nullopt;
::arrow::compute::Expression filter_ = ::arrow::compute::literal(true);
std::optional<int64_t> limit_ = std::nullopt;
int64_t offset_ = 0;
};

} // namespace lance::arrow
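
With these header changes, every builder setter now returns ::arrow::Status and Finish() is no longer const. A minimal usage sketch of the reworked API (BuildScanner is a hypothetical helper, not part of the PR; the column names and filter are borrowed from the tests further down; an already-constructed ::arrow::dataset::Dataset is assumed):

#include "lance/arrow/scanner.h"

// Hypothetical helper: build a scanner with the new Status-returning setters,
// including the BatchSize() call added in this PR.
::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> BuildScanner(
    std::shared_ptr<::arrow::dataset::Dataset> dataset) {
  auto builder = lance::arrow::ScannerBuilder(dataset);
  ARROW_RETURN_NOT_OK(builder.Project({"objects.val"}));
  ARROW_RETURN_NOT_OK(builder.Filter(::arrow::compute::equal(
      ::arrow::compute::field_ref({"objects", 0, "val"}), ::arrow::compute::literal(2))));
  ARROW_RETURN_NOT_OK(builder.BatchSize(64));       // new in this PR
  ARROW_RETURN_NOT_OK(builder.Limit(10, /*offset=*/2));
  return builder.Finish();
}
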
48 changes: 28 additions & 20 deletions cpp/src/lance/arrow/scanner.cc
@@ -27,30 +27,34 @@
namespace lance::arrow {

ScannerBuilder::ScannerBuilder(std::shared_ptr<::arrow::dataset::Dataset> dataset)
: dataset_(dataset) {}
: builder_(dataset) {}

void ScannerBuilder::Project(const std::vector<std::string>& columns) { columns_ = columns; }
::arrow::Status ScannerBuilder::Project(const std::vector<std::string>& columns) {
columns_ = columns;
return ::arrow::Status::OK();
}

void ScannerBuilder::Filter(const ::arrow::compute::Expression& filter) { filter_ = filter; }
::arrow::Status ScannerBuilder::Filter(const ::arrow::compute::Expression& filter) {
return builder_.Filter(filter);
}

void ScannerBuilder::Limit(int64_t limit, int64_t offset) {
limit_ = limit;
offset_ = offset;
::arrow::Status ScannerBuilder::BatchSize(int64_t batch_size) {
return builder_.BatchSize(batch_size);
}

::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> ScannerBuilder::Finish() const {
if (offset_ < 0) {
return ::arrow::Status::Invalid("Offset is negative");
::arrow::Status ScannerBuilder::Limit(int64_t limit, int64_t offset) {
if (limit <= 0 || offset < 0) {
return ::arrow::Status::Invalid("Limit / offset is invalid: limit=", limit, " offset=", offset);
}
auto builder = ::arrow::dataset::ScannerBuilder(dataset_);
ARROW_RETURN_NOT_OK(builder.Filter(filter_));

auto fragment_opts = std::make_shared<LanceFragmentScanOptions>();
fragment_opts->limit = limit_;
fragment_opts->offset = offset_;
ARROW_RETURN_NOT_OK(builder.FragmentScanOptions(fragment_opts));
fragment_opts->limit = limit;
fragment_opts->offset = offset;

ARROW_ASSIGN_OR_RAISE(auto scanner, builder.Finish());
return builder_.FragmentScanOptions(fragment_opts);
}

::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> ScannerBuilder::Finish() {
ARROW_ASSIGN_OR_RAISE(auto scanner, builder_.Finish());

/// We do the schema projection manually here to support nested structs.
/// Esp. for `list<struct>`, supports Spark-like access, for example,
@@ -97,10 +101,14 @@ ::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> ScannerBuilder::Finish
scanner->options()->projection = project_desc.expression;
}

if (limit_.has_value()) {
scanner->options()->batch_size = offset_ + limit_.value();
/// We need to limit the parallelism for Project to calculate LIMIT / Offset
scanner->options()->batch_readahead = 1;
if (scanner->options()->fragment_scan_options) {
auto fso = std::dynamic_pointer_cast<LanceFragmentScanOptions>(
scanner->options()->fragment_scan_options);
if (fso->limit) {
scanner->options()->batch_size = fso->limit.value();
/// We need to limit the parallelism for Project to calculate LIMIT / Offset
scanner->options()->batch_readahead = 1;
}
}

return scanner;
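
The net effect in scanner.cc is that Limit() no longer stores limit/offset on the builder; it forwards them through LanceFragmentScanOptions, and Finish() reads them back to cap batch_size at the limit and force batch_readahead to 1. An illustrative check of that propagation (CheckLimitPropagation is hypothetical and not part of the PR; it assumes LanceFragmentScanOptions is visible the same way it is inside scanner.cc):

::arrow::Status CheckLimitPropagation(std::shared_ptr<::arrow::dataset::Dataset> dataset) {
  auto builder = lance::arrow::ScannerBuilder(dataset);
  ARROW_RETURN_NOT_OK(builder.Limit(10));
  ARROW_ASSIGN_OR_RAISE(auto scanner, builder.Finish());
  // Limit() stashes the limit in the fragment scan options; mirror the cast Finish() performs.
  auto fso = std::dynamic_pointer_cast<LanceFragmentScanOptions>(
      scanner->options()->fragment_scan_options);
  if (!fso || !fso->limit || fso->limit.value() != 10) {
    return ::arrow::Status::Invalid("limit was not forwarded to LanceFragmentScanOptions");
  }
  // With a limit present, Finish() caps batch_size at the limit and serializes readahead
  // so LIMIT / OFFSET can be evaluated deterministically.
  if (scanner->options()->batch_size != 10 || scanner->options()->batch_readahead != 1) {
    return ::arrow::Status::Invalid("batch_size / batch_readahead were not adjusted");
  }
  return ::arrow::Status::OK();
}
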
88 changes: 75 additions & 13 deletions cpp/src/lance/arrow/scanner_test.cc
@@ -20,16 +20,15 @@
#include <arrow/table.h>
#include <arrow/type.h>
#include <fmt/format.h>
#include <fmt/ranges.h>

#include <catch2/catch_test_macros.hpp>
#include <memory>

#include "lance/arrow/type.h"
#include "lance/arrow/stl.h"
#include "lance/arrow/testing.h"
#include "lance/arrow/type.h"
#include "lance/format/schema.h"


auto nested_schema = ::arrow::schema({::arrow::field("pk", ::arrow::int32()),
::arrow::field("objects",
::arrow::list(::arrow::struct_({
@@ -62,10 +61,12 @@ TEST_CASE("Build Scanner with nested struct") {
auto table = ::arrow::Table::MakeEmpty(nested_schema).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
auto scanner_builder = lance::arrow::ScannerBuilder(dataset);
scanner_builder.Limit(10);
scanner_builder.Project({"objects.val"});
scanner_builder.Filter(::arrow::compute::equal(::arrow::compute::field_ref({"objects", 0, "val"}),
::arrow::compute::literal(2)));
CHECK(scanner_builder.Limit(10).ok());
CHECK(scanner_builder.Project({"objects.val"}).ok());
CHECK(scanner_builder
.Filter(::arrow::compute::equal(::arrow::compute::field_ref({"objects", 0, "val"}),
::arrow::compute::literal(2)))
.ok());
auto result = scanner_builder.Finish();
CHECK(result.ok());
auto scanner = result.ValueOrDie();
@@ -83,7 +84,6 @@ TEST_CASE("Build Scanner with nested struct") {
fmt::print("Scanner Options: {}\n", scanner->options()->filter.ToString());
}


std::shared_ptr<::arrow::Table> MakeTable() {
auto ext_type = std::make_shared<::lance::testing::ParametricType>(1);
::arrow::StringBuilder stringBuilder;
@@ -97,8 +97,8 @@ std::shared_ptr<::arrow::Table> MakeTable() {
auto c2 = intBuilder.Finish().ValueOrDie();
intBuilder.Reset();

auto schema = ::arrow::schema({arrow::field("c1", ::arrow::utf8()),
arrow::field("c2", ext_type)});
auto schema =
::arrow::schema({arrow::field("c1", ::arrow::utf8()), arrow::field("c2", ext_type)});

std::vector<std::shared_ptr<::arrow::Array>> cols;
cols.push_back(c1);
@@ -109,8 +109,8 @@ std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow::Table> table) {
std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow::Table> table) {
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
auto scanner_builder = lance::arrow::ScannerBuilder(dataset);
scanner_builder.Limit(2);
scanner_builder.Project({"c2"});
CHECK(scanner_builder.Limit(2).ok());
CHECK(scanner_builder.Project({"c2"}).ok());
// TODO how can extension types implement comparisons for filtering against storage type?
auto result = scanner_builder.Finish();
CHECK(result.ok());
@@ -119,7 +119,6 @@ std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow::Table> table) {
return scanner;
}


TEST_CASE("Scanner with extension") {
auto table = MakeTable();
auto ext_type = std::make_shared<::lance::testing::ParametricType>(1);
@@ -140,4 +139,67 @@
auto actual_table = scanner->ToTable().ValueOrDie();
CHECK(actual_table->schema()->Equals(expected_proj_schema));
CHECK(actual_table->GetColumnByName("c2")->type()->Equals(ext_type));
}

::arrow::Result<std::shared_ptr<::arrow::dataset::Scanner>> MakeScannerForBatchScan(
int64_t num_values, int64_t batch_size) {
std::vector<int32_t> values(num_values);
std::iota(values.begin(), values.end(), 0);
auto arr = lance::arrow::ToArray(values).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("value", ::arrow::int32())}), {arr});

auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
auto scanner_builder = lance::arrow::ScannerBuilder(dataset);
CHECK(scanner_builder.BatchSize(batch_size).ok());
return scanner_builder.Finish();
}

TEST_CASE("Test Scanner::ToRecordBatchReader with batch size") {
const int kTotalValues = 100;
const int kBatchSize = 4;
auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie();
auto record_batch_reader = scanner->ToRecordBatchReader().ValueOrDie();
int num_batches = 0;
while (auto batch = record_batch_reader->Next().ValueOrDie()) {
CHECK(batch->num_rows() == kBatchSize);
num_batches++;
}
CHECK(num_batches == kTotalValues / kBatchSize);
}

TEST_CASE("Test Scanner::ScanBatch with batch size") {
const int kTotalValues = 100;
const int kBatchSize = 4;
auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie();
auto batches = scanner->ScanBatches().ValueOrDie();
int num_batches = 0;
while (true) {
auto batch = batches.Next().ValueOrDie();
if (!batch.record_batch) {
break;
}
CHECK(batch.record_batch->num_rows() == kBatchSize);
num_batches++;
}
CHECK(num_batches == kTotalValues / kBatchSize);
}

TEST_CASE("Test ScanBatchesAsync with batch size") {
const int kTotalValues = 100;
const int kBatchSize = 4;
auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie();
auto generator = scanner->ScanBatchesAsync().ValueOrDie();
int num_batches = 0;
while (true) {
auto fut = generator();
CHECK(fut.Wait(1));
auto batch = fut.result().ValueOrDie();
if (!batch.record_batch) {
break;
}
num_batches++;
CHECK(batch.record_batch->num_rows() == kBatchSize);
}
CHECK(num_batches == kTotalValues / kBatchSize);
}
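
The three batch-size tests above choose kTotalValues = 100 and kBatchSize = 4, so every batch is exactly kBatchSize rows. If the total were not a multiple of the batch size, the trailing batch would be shorter. A hedged variant of the reader loop that tolerates a short final batch (hypothetical test, not part of the PR) could look like:

TEST_CASE("Batch size with a partial trailing batch") {
  const int kTotalValues = 10;  // 10 = 4 + 4 + 2: the last batch is short
  const int kBatchSize = 4;
  auto scanner = MakeScannerForBatchScan(kTotalValues, kBatchSize).ValueOrDie();
  auto reader = scanner->ToRecordBatchReader().ValueOrDie();
  int64_t rows_seen = 0;
  while (auto batch = reader->Next().ValueOrDie()) {
    CHECK(batch->num_rows() <= kBatchSize);  // batches never exceed the requested size
    rows_seen += batch->num_rows();
  }
  CHECK(rows_seen == kTotalValues);
}
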
8 changes: 5 additions & 3 deletions cpp/src/lance/io/project_test.cc
@@ -46,9 +46,11 @@ TEST_CASE("Project schema") {
auto dataset = std::make_shared<arrow::dataset::InMemoryDataset>(tbl);

auto scan_builder = lance::arrow::ScannerBuilder(dataset);
scan_builder.Project({"v"});
scan_builder.Filter(
arrow::compute::equal(arrow::compute::field_ref("v"), arrow::compute::literal(20)));
CHECK(scan_builder.Project({"v"}).ok());
CHECK(scan_builder
.Filter(
arrow::compute::equal(arrow::compute::field_ref("v"), arrow::compute::literal(20)))
.ok());
auto scanner = scan_builder.Finish().ValueOrDie();

auto lance_schema = lance::format::Schema(schema);