Skip to content

Commit

Permalink
add indices field to ScanBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Sep 3, 2022
1 parent 451e5b0 commit d3f97f9
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 46 deletions.
2 changes: 1 addition & 1 deletion cpp/src/lance/io/exec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ add_library(
scan.h
take.cc
take.h
)
base.cc)
target_include_directories(lance_io_exec SYSTEM PRIVATE ${Protobuf_INCLUDE_DIR})
add_dependencies(lance_io_exec format)

Expand Down
33 changes: 33 additions & 0 deletions cpp/src/lance/io/exec/base.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2022 Lance Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lance/io/exec/base.h"

#include <memory>
namespace lance::io::exec {

ScanBatch ScanBatch::Null() { return ScanBatch(nullptr, -1); }

ScanBatch::ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, int32_t bid)
: batch(records), batch_id(bid) {}

ScanBatch ScanBatch::Filtered(std::shared_ptr<::arrow::RecordBatch> records,
int32_t batch_id,
std::shared_ptr<::arrow::Int32Array> indices) {
auto batch = ScanBatch(records, batch_id);
batch.indices = std::move(indices);
return batch;
}

} // namespace lance::io::exec
50 changes: 49 additions & 1 deletion cpp/src/lance/io/exec/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,65 @@

namespace lance::io::exec {

/// Emitted results from each ExecNode
struct ScanBatch {
/// The resulted RecordBatch.
///
/// If it is zero-sized batch, there is no valid values in this batch.
/// It it is nullptr, it reaches the end of the scan.
std::shared_ptr<::arrow::RecordBatch> batch;
int32_t batch_id;

/// The Id of the batch this result belongs to.
int32_t batch_id = -1;

/// Indices returned from the filter.
std::shared_ptr<::arrow::Int32Array> indices;

/// Return a null ScanBatch indicates EOF.
static ScanBatch Null();

/// Constructor with a record batch and batch id.
///
/// \param records A record batch of values to return
/// \param batch_id the id of the batch
static ScanBatch Filtered(std::shared_ptr<::arrow::RecordBatch> records,
int32_t batch_id,
std::shared_ptr<::arrow::Int32Array> indices);

/// Construct an empty response.
ScanBatch() = default;

/// Constructor with a record batch and batch id.
///
/// \param records A record batch of values to return
/// \param batch_id the id of the batch
ScanBatch(std::shared_ptr<::arrow::RecordBatch> records, int32_t batch_id);

/// Returns True if the end of file is reached.
bool eof() const { return !batch; }
};

/// I/O execute base node.
///
/// TODO: investigate to adapt Arrow Acero.
/// https://arrow.apache.org/docs/cpp/streaming_execution.html
///
/// A exec plan is usually starts with Project and ends with Scan.
///
/// \example
/// A few examples of the exec plan tree for common queries.
///
/// SELECT * FROM dataset
/// Project (*) --> Scan (*)
///
/// SELECT a, b FROM dataset WHERE c = 123
/// Project (a, b) -> Take(a,b) -> Filter(c=123) -> Scan(c)
///
/// SELECT a, b FROM dataset LIMIT 200 OFFSET 5000
/// Project (a, b) -> Limit(200, 5000) -> Scan(a, b)
///
/// SELECT a, b, c FROM dataset WHERE c = 123 LIMIT 50 OFFSET 200
/// Project (a, b, c) -> Take(a, b) -> Limit(50, 200) -> Filter(c=123) -> Scan(c)
class ExecNode {
public:
enum Type {
Expand Down
9 changes: 2 additions & 7 deletions cpp/src/lance/io/exec/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,15 @@ bool Filter::HasFilter(const ::arrow::compute::Expression& filter) {
::arrow::Result<ScanBatch> Filter::Next() {
ARROW_ASSIGN_OR_RAISE(auto batch, child_->Next());
if (batch.eof()) {
return ScanBatch{};
return ScanBatch::Null();
}
if (batch.batch->num_rows() == 0) {
return batch;
}
ARROW_ASSIGN_OR_RAISE(auto indices_and_values, Apply(*batch.batch));
auto [indices, values] = indices_and_values;
ARROW_ASSIGN_OR_RAISE(auto values_arr, values->ToStructArray());
ARROW_ASSIGN_OR_RAISE(
auto struct_arr,
::arrow::StructArray::Make({indices, values_arr},
std::vector<std::string>({"indices", "values"})));
ARROW_ASSIGN_OR_RAISE(auto result_batch, ::arrow::RecordBatch::FromStructArray(struct_arr));
return ScanBatch{result_batch, batch.batch_id};
return ScanBatch::Filtered(values, batch.batch_id, indices);
}

::arrow::Result<
Expand Down
25 changes: 10 additions & 15 deletions cpp/src/lance/io/exec/filter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,33 @@ TEST_CASE("value = 32") {
auto bar = lance::arrow::ToArray({1, 2, 32, 0, 32}).ValueOrDie();
auto struct_arr =
::arrow::StructArray::Make({bar}, {::arrow::field("value", ::arrow::int32())}).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("value", ::arrow::int32())}), {bar});
auto schema = ::arrow::schema({::arrow::field("value", ::arrow::int32())});
auto table = ::arrow::Table::Make(schema, {bar});
auto batch = ::arrow::RecordBatch::FromStructArray(struct_arr).ValueOrDie();

auto filter = Filter::Make(expr, TableScan::Make(*table)).ValueOrDie();
auto filtered_batch = filter->Next().ValueOrDie();
auto indices = filtered_batch.batch->GetColumnByName("indices");
auto output = filtered_batch.batch->GetColumnByName("values");
CHECK(indices->Equals(lance::arrow::ToArray({2, 4}).ValueOrDie()));
CHECK(filtered_batch.indices->Equals(lance::arrow::ToArray({2, 4}).ValueOrDie()));

bar = lance::arrow::ToArray({32, 32}).ValueOrDie();
struct_arr =
::arrow::StructArray::Make({bar}, {::arrow::field("value", ::arrow::int32())}).ValueOrDie();
CHECK(output->Equals(struct_arr));
auto expected = ::arrow::RecordBatch::Make(schema, 2, {bar});
CHECK(filtered_batch.batch->Equals(*expected));
}

TEST_CASE("label = cat or label = dog") {
auto expr =
or_(equal(field_ref("label"), literal("cat")), equal(field_ref("label"), literal("dog")));
auto labels =
lance::arrow::ToArray({"person", "dog", "cat", "car", "cat", "food", "hotdog"}).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("label", ::arrow::utf8())}), {labels});
auto schema = ::arrow::schema({::arrow::field("label", ::arrow::utf8())});
auto table = ::arrow::Table::Make(schema, {labels});

auto filter = Filter::Make(expr, TableScan::Make(*table)).ValueOrDie();
auto filtered_batch = filter->Next().ValueOrDie();
auto indices = filtered_batch.batch->GetColumnByName("indices");
auto output = filtered_batch.batch->GetColumnByName("values");
CHECK(indices->Equals(lance::arrow::ToArray({1, 2, 4}).ValueOrDie()));
CHECK(filtered_batch.indices->Equals(lance::arrow::ToArray({1, 2, 4}).ValueOrDie()));

labels = lance::arrow::ToArray({"dog", "cat", "cat"}).ValueOrDie();
auto struct_arr =
::arrow::StructArray::Make({labels}, {::arrow::field("label", ::arrow::utf8())}).ValueOrDie();
CHECK(output->Equals(struct_arr));
auto expected = ::arrow::RecordBatch::Make(schema, 3, {labels});
CHECK(filtered_batch.batch->Equals(*expected));
}
2 changes: 1 addition & 1 deletion cpp/src/lance/io/exec/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ::arrow::Result<ScanBatch> Scan::Next() {
fmt::print("Batch id: {} total batches={}\n", batch_id, reader_->metadata().num_batches());
if (batch_id >= reader_->metadata().num_batches()) {
// Reach EOF
return ScanBatch();
return ScanBatch::Null();
}

ARROW_ASSIGN_OR_RAISE(auto batch, reader_->ReadBatch(*schema_, batch_id, offset, batch_size_));
Expand Down
31 changes: 10 additions & 21 deletions cpp/src/lance/io/exec/take.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,19 @@ ::arrow::Result<std::unique_ptr<Take>> Take::Make(std::shared_ptr<FileReader> re
::arrow::Result<ScanBatch> Take::Next() {
ARROW_ASSIGN_OR_RAISE(auto filtered, child_->Next());
if (filtered.eof()) {
return ScanBatch{};
return ScanBatch::Null();
}
auto indices = filtered.batch->GetColumnByName("indices");
auto vals = filtered.batch->GetColumnByName("values");
if (!indices || indices->type_id() != ::arrow::Type::INT32 || !vals ||
vals->type_id() != ::arrow::Type::STRUCT) {
return ::arrow::Status::Invalid("Invalid data from filter node: batch=",
filtered.batch->ToString());
}
auto values = std::reinterpret_pointer_cast<::arrow::StructArray>(vals);
assert(filtered.indices);
const auto batch_id = filtered.batch_id;
if (!schema_ || schema_->fields().empty()) {
return ScanBatch{::arrow::RecordBatch::FromStructArray(vals).ValueOrDie(), filtered.batch_id};
return ScanBatch(filtered.batch, batch_id);
} else {
auto& batch_id = filtered.batch_id;
auto int32_indices = std::dynamic_pointer_cast<::arrow::Int32Array>(indices);
ARROW_ASSIGN_OR_RAISE(auto filtered_record_batch, ::arrow::RecordBatch::FromStructArray(vals));
ARROW_ASSIGN_OR_RAISE(auto batch, reader_->ReadBatch(*schema_, batch_id, int32_indices));
assert(filtered_record_batch->num_rows() == batch->num_rows());
fmt::print("Merge scan results: filtered={} extra={}\n",
filtered_record_batch->ToString(),
batch->schema()->ToString());
ARROW_ASSIGN_OR_RAISE(auto merged,
lance::arrow::MergeRecordBatches(filtered_record_batch, batch));
return ScanBatch{merged, filtered.batch_id};
ARROW_ASSIGN_OR_RAISE(auto rest_columns,
reader_->ReadBatch(*schema_, batch_id, filtered.indices));
assert(filtered.batch->num_rows() == rest_columns->num_rows());
ARROW_ASSIGN_OR_RAISE(auto merged_batch,
lance::arrow::MergeRecordBatches(filtered.batch, rest_columns));
return ScanBatch(merged_batch, filtered.batch_id);
}
}

Expand Down

0 comments on commit d3f97f9

Please sign in to comment.