diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 1639ae1661..72c7fbc873 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -130,6 +130,57 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> NewUpdate( const std::shared_ptr<::arrow::Field>& new_field) const; + ::arrow::Result> NewUpdate( + const std::shared_ptr<::arrow::Schema>& new_columns) const; + + /// Merge an in-memory table into this dataset, appending all of its columns except the "right_on" column. + /// + /// The algorithm follows the semantics of `LEFT JOIN` in SQL. + /// Unlike LEFT JOIN, this function does not allow one row + /// on the left ("this" dataset) to map to two distinct rows on the right ("other"). + /// However, if no matching row is found on the right side, a NULL value is filled in. + /// + /// For example, + /// + /// \code + /// dataset (left) = { + /// "id": [1, 2, 3, 4], + /// "vals": ["a", "b", "c", "d"], + /// } + /// table (right) = { + /// "id": [5, 1, 10, 3, 8], + /// "attrs": [5.0, 1.0, 10.0, 3.0, 8.0], + /// } + /// + /// dataset.Merge(table, on="id") => + /// { + /// "id": [1, 2, 3, 4], + /// "vals": ["a", "b", "c", "d"], + /// "attrs": [1.0, Null, 3.0, Null], + /// } + /// \endcode + /// + /// \param right the table to merge with this dataset. + /// \param left_on the column in this dataset to join on. + /// \param right_on the column in the table to join on. + /// The join columns must exist on both sides and have the same data type. + /// \param pool memory pool + /// \return a new version of the dataset on success. + /// + ::arrow::Result> Merge( + const std::shared_ptr<::arrow::Table>& right, + const std::string& left_on, + const std::string& right_on, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + /// Merge an in-memory table; both sides must have the column specified by the "on" name. + /// + /// See `Merge(right, left_on, right_on, pool)` for details. + ::arrow::Result> Merge( + const std::shared_ptr<::arrow::Table>& right, + const std::string& on, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + ::arrow::Result> ReplaceSchema( std::shared_ptr<::arrow::Schema> schema) const override; diff --git a/cpp/include/lance/arrow/updater.h b/cpp/include/lance/arrow/updater.h index 0c9df6c518..18a6624ea2 100644 --- a/cpp/include/lance/arrow/updater.h +++ b/cpp/include/lance/arrow/updater.h @@ -67,6 +67,10 @@ class Updater { /// The array must has the same length as the batch returned previously via `Next()`. ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::Array>& arr); + /// Update the values to new values, provided as a `RecordBatch`. + /// The batch must have the same length as the batch returned previously via `Next()`. + ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch); + /// Finish the update and returns a new version of dataset. ::arrow::Result> Finish(); @@ -74,13 +78,13 @@ class Updater { /// Make a new Updater /// /// \param dataset The dataset to be updated. - /// \param field the (new) column to update. + /// \param schema the (new) columns to update. /// \param projection_columns the columns to read from source dataset. /// /// \return an Updater if success. static ::arrow::Result> Make( std::shared_ptr dataset, - const std::shared_ptr<::arrow::Field>& field, + const std::shared_ptr<::arrow::Schema>& schema, const std::vector& projection_columns); /// PIMPL @@ -99,7 +103,7 @@ /// parameters to build a Updater. 
class UpdaterBuilder { public: - UpdaterBuilder(std::shared_ptr dataset, std::shared_ptr<::arrow::Field> field); + UpdaterBuilder(std::shared_ptr dataset, std::shared_ptr<::arrow::Schema> schema); /// Set the projection columns from the source dataset. void Project(std::vector columns); @@ -109,7 +113,7 @@ class UpdaterBuilder { private: std::shared_ptr dataset_; - std::shared_ptr<::arrow::Field> field_; + std::shared_ptr<::arrow::Schema> schema_; std::vector projection_columns_; }; diff --git a/cpp/src/lance/arrow/CMakeLists.txt b/cpp/src/lance/arrow/CMakeLists.txt index 8b26e54bc3..dac655a17e 100644 --- a/cpp/src/lance/arrow/CMakeLists.txt +++ b/cpp/src/lance/arrow/CMakeLists.txt @@ -24,6 +24,8 @@ add_library( file_lance_ext.h fragment.cc fragment.h + hash_merger.cc + hash_merger.h scanner.cc stl.h type.cc @@ -39,6 +41,7 @@ add_lance_test(api_test) add_lance_test(arrow_dataset_test) add_lance_test(dataset_test) add_lance_test(fragment_test) +add_lance_test(hash_merger_test) add_lance_test(scanner_test) add_lance_test(type_test) add_lance_test(updater_test) diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 5723159882..d65d29a211 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -18,17 +18,22 @@ #include #include #include +#include #include #include #include #include #include +#include +#include +#include #include #include "lance/arrow/dataset_ext.h" #include "lance/arrow/file_lance.h" #include "lance/arrow/fragment.h" +#include "lance/arrow/hash_merger.h" #include "lance/arrow/updater.h" #include "lance/arrow/utils.h" #include "lance/format/manifest.h" @@ -313,8 +318,12 @@ DatasetVersion LanceDataset::version() const { return impl_->manifest->GetDatase ::arrow::Result> LanceDataset::NewUpdate( const std::shared_ptr<::arrow::Field>& new_field) const { - return std::make_shared(std::make_shared(*this), - std::move(new_field)); + return NewUpdate(::arrow::schema({new_field})); +} + +::arrow::Result> LanceDataset::NewUpdate( + const std::shared_ptr<::arrow::Schema>& new_columns) const { + return std::make_shared(std::make_shared(*this), new_columns); } ::arrow::Result> LanceDataset::AddColumn( @@ -382,4 +391,62 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm return ::arrow::MakeVectorIterator(fragments); } +::arrow::Result> LanceDataset::Merge( + const std::shared_ptr<::arrow::Table>& other, + const std::string& on, + ::arrow::MemoryPool* pool) { + return Merge(other, on, on, pool); +} + +::arrow::Result> LanceDataset::Merge( + const std::shared_ptr<::arrow::Table>& right, + const std::string& left_on, + const std::string& right_on, + ::arrow::MemoryPool* pool) { + /// Sanity checks + auto left_column = schema_->GetFieldByName(left_on); + if (left_column == nullptr) { + return ::arrow::Status::Invalid( + fmt::format("Column {} does not exist in the dataset.", left_on)); + } + auto right_column = right->GetColumnByName(right_on); + if (right_column == nullptr) { + return ::arrow::Status::Invalid( + fmt::format("Column {} does not exist in the table.", right_on)); + } + + auto& left_type = left_column->type(); + auto& right_type = right_column->type(); + if (!left_type->Equals(right_type)) { + return ::arrow::Status::Invalid("LanceDataset::AddColumns: types are not equal: ", + left_type->ToString(), + " != ", + right_type->ToString()); + } + + // First phase, build hash table (in memory for simplicity) + auto merger = HashMerger(right, right_on, pool); + ARROW_RETURN_NOT_OK(merger.Init()); + + 
// Second phase + auto table_schema = right->schema(); + ARROW_ASSIGN_OR_RAISE(auto incoming_schema, + table_schema->RemoveField(table_schema->GetFieldIndex(right_on))); + ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(incoming_schema))); + update_builder->Project({left_on}); + ARROW_ASSIGN_OR_RAISE(auto updater, update_builder->Finish()); + + while (true) { + ARROW_ASSIGN_OR_RAISE(auto batch, updater->Next()); + if (!batch) { + break; + } + assert(batch->schema()->Equals(::arrow::schema({left_column}))); + auto index_arr = batch->GetColumnByName(left_on); + ARROW_ASSIGN_OR_RAISE(auto right_batch, merger.Collect(index_arr)); + ARROW_RETURN_NOT_OK(updater->UpdateBatch(right_batch)); + } + return updater->Finish(); +} + } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/dataset_test.cc b/cpp/src/lance/arrow/dataset_test.cc index eb8334287d..b99f465682 100644 --- a/cpp/src/lance/arrow/dataset_test.cc +++ b/cpp/src/lance/arrow/dataset_test.cc @@ -283,4 +283,37 @@ TEST_CASE("Dataset add column with a function call") { ::arrow::field("doubles", ::arrow::float64())}), {ids, doubles}); CHECK(table2->Equals(*expected_table)); +} + +TEST_CASE("Dataset add columns with a table") { + auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie(); + auto schema = ::arrow::schema( + {::arrow::field("id", ::arrow::int32()), ::arrow::field("value", ::arrow::utf8())}); + auto table = ::arrow::Table::Make(schema, {ids, values}); + auto base_uri = WriteTable(table); + + auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); + auto dataset = lance::arrow::LanceDataset::Make(fs, base_uri).ValueOrDie(); + CHECK(dataset->version().version() == 1); + + auto added_ids = ToArray({5, 4, 3, 10, 12, 1}).ValueOrDie(); + auto added_values = ToArray({50, 40, 30, 100, 120, 10}).ValueOrDie(); + auto added_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("new_value", ::arrow::int32())}), + {added_ids, added_values}); + auto new_dataset = dataset->Merge(added_table, "id").ValueOrDie(); + CHECK(new_dataset->version().version() == 2); + auto new_table = + new_dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); + + // TODO: Plain array does not support null yet, so arr[1] = 0 instead of Null. + auto new_values = ToArray({10, 0, 30, 40, 50}).ValueOrDie(); + auto expected_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("value", ::arrow::utf8()), + ::arrow::field("new_value", ::arrow::int32())}), + {ids, values, new_values}); + CHECK(new_table->Equals(*expected_table)); } \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc new file mode 100644 index 0000000000..c47312afc3 --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -0,0 +1,167 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lance/arrow/hash_merger.h" + +#include +#include + +#include + +#include "lance/arrow/stl.h" +#include "lance/arrow/type.h" + +namespace lance::arrow { + +class HashMerger::Impl { + public: + virtual ~Impl() = default; + + virtual void ComputeHash(const std::shared_ptr<::arrow::Array>& arr, + std::vector>* out) = 0; + + virtual ::arrow::Result> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) = 0; +}; + +template ::CType> +class TypedHashMerger : public HashMerger::Impl { + public: + void ComputeHash(const std::shared_ptr<::arrow::Array>& arr, + std::vector>* out) override { + auto hash_func = std::hash{}; + assert(out); + auto values = std::dynamic_pointer_cast::ArrayType>(arr); + assert(values); + out->reserve(values->length()); + out->clear(); + for (int i = 0; i < values->length(); ++i) { + if (values->IsNull(i)) { + out->emplace_back(std::nullopt); + } else { + auto value = values->Value(i); + out->emplace_back(hash_func(value)); + } + } + } + + ::arrow::Result> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) override { + std::unordered_map key_to_chunk_index; + int64_t index = 0; + std::vector> hashes; + for (const auto& chunk : chunked_arr->chunks()) { + ComputeHash(chunk, &hashes); + assert(chunk->length() == static_cast(hashes.size())); + for (std::size_t i = 0; i < hashes.size(); i++) { + const auto& key = hashes[i]; + if (key.has_value()) { + auto ret = key_to_chunk_index.emplace(key.value(), index); + if (!ret.second) { + auto values = + std::dynamic_pointer_cast::ArrayType>(chunk); + return ::arrow::Status::IndexError("Duplicate key found: ", values->Value(i)); + } + } + index++; + } + } + return key_to_chunk_index; + } +}; + +HashMerger::HashMerger(std::shared_ptr<::arrow::Table> table, + std::string index_column, + ::arrow::MemoryPool* pool) + : table_(std::move(table)), column_name_(std::move(index_column)), pool_(pool) {} + +HashMerger::~HashMerger() {} + +::arrow::Status HashMerger::Init() { + auto chunked_arr = table_->GetColumnByName(column_name_); + if (chunked_arr == nullptr) { + return ::arrow::Status::Invalid("index column ", column_name_, " does not exist"); + } + index_column_type_ = chunked_arr->type(); + +#define BUILD_IMPL(TypeId) \ + case TypeId: \ + impl_ = std::unique_ptr( \ + new TypedHashMerger::Type>()); \ + break; + + switch (index_column_type_->id()) { + BUILD_IMPL(::arrow::Type::UINT8); + BUILD_IMPL(::arrow::Type::INT8); + BUILD_IMPL(::arrow::Type::UINT16); + BUILD_IMPL(::arrow::Type::INT16); + BUILD_IMPL(::arrow::Type::UINT32); + BUILD_IMPL(::arrow::Type::INT32); + BUILD_IMPL(::arrow::Type::UINT64); + BUILD_IMPL(::arrow::Type::INT64); + case ::arrow::Type::HALF_FLOAT: + case ::arrow::Type::FLOAT: + case ::arrow::Type::DOUBLE: + return ::arrow::Status::Invalid("Do not support merge on floating points"); + case ::arrow::Type::STRING: + impl_ = std::unique_ptr(new TypedHashMerger<::arrow::StringType, std::string_view>()); + break; + default: + return ::arrow::Status::Invalid("Only support primitive or string type, got: ", + index_column_type_->ToString()); + } + + ARROW_ASSIGN_OR_RAISE(index_map_, impl_->BuildHashChunkIndex(chunked_arr)); + + return ::arrow::Status::OK(); +} + +::arrow::Result> HashMerger::Collect( + const std::shared_ptr<::arrow::Array>& index_arr) { + if (!index_arr->type()->Equals(index_column_type_)) { + return ::arrow::Status::TypeError( + "Index 
column match mismatch: ", index_arr->type()->ToString(), " != ", index_column_type_); + } + std::vector> hashes; + impl_->ComputeHash(index_arr, &hashes); + ::arrow::Int64Builder indices_builder; + ARROW_RETURN_NOT_OK(indices_builder.Reserve(index_arr->length())); + for (const auto& hvalue : hashes) { + if (hvalue.has_value()) { + auto it = index_map_.find(hvalue.value()); + if (it != index_map_.end()) { + ARROW_RETURN_NOT_OK(indices_builder.Append(it->second)); + } else { + ARROW_RETURN_NOT_OK(indices_builder.AppendNull()); + } + } else { + ARROW_RETURN_NOT_OK(indices_builder.AppendNull()); + } + } + ARROW_ASSIGN_OR_RAISE(auto indices_arr, indices_builder.Finish()); + ARROW_ASSIGN_OR_RAISE(auto datum, ::arrow::compute::Take(table_, indices_arr)); + assert(datum.table()); + auto table = datum.table(); + + // Drop the index column. + for (int i = 0; i < table->num_columns(); ++i) { + if (table->field(i)->name() == column_name_) { + ARROW_ASSIGN_OR_RAISE(table, table->RemoveColumn(i)); + break; + } + } + return table->CombineChunksToBatch(pool_); +} + +} // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h new file mode 100644 index 0000000000..6c04d8b55a --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger.h @@ -0,0 +1,67 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "lance/arrow/type.h" + +namespace lance::arrow { + +/// A basic implementation of in-memory hash (join) merge. +/// +class HashMerger { + public: + HashMerger() = delete; + + /// HashMerger constructor. + explicit HashMerger(std::shared_ptr<::arrow::Table> table, + std::string index_column, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + ~HashMerger(); + + /// Build a hash map on column specified by "column". + ::arrow::Status Init(); + + /// Collect the batch records with the same keys in the column. + ::arrow::Result> Collect( + const std::shared_ptr<::arrow::Array>& index_arr); + + private: + std::shared_ptr<::arrow::Table> table_; + std::string column_name_; + + class Impl; + std::unique_ptr impl_; + /// A map from `std::hash(key)` to the index (`int64_t`) in the table. + std::unordered_map index_map_; + std::shared_ptr<::arrow::DataType> index_column_type_; + ::arrow::MemoryPool* pool_; + + template + friend class TypedHashMerger; +}; + +} // namespace lance::arrow diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc new file mode 100644 index 0000000000..0d9619c3ff --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -0,0 +1,131 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lance/arrow/hash_merger.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "lance/arrow/stl.h" +#include "lance/arrow/type.h" +#include "lance/testing/json.h" + +using lance::arrow::HashMerger; + +template +std::shared_ptr<::arrow::Table> MakeTable() { + std::vector::CType> keys; + typename ::arrow::TypeTraits::BuilderType keys_builder; + typename ::arrow::StringBuilder value_builder; + ::arrow::ArrayVector key_arrs, value_arrs; + for (int chunk = 0; chunk < 5; chunk++) { + for (int i = 0; i < 10; i++) { + typename ::arrow::TypeTraits::CType value = chunk * 10 + i; + CHECK(keys_builder.Append(value).ok()); + CHECK(value_builder.Append(fmt::format("{}", value)).ok()); + } + auto keys_arr = keys_builder.Finish().ValueOrDie(); + auto values_arr = value_builder.Finish().ValueOrDie(); + key_arrs.emplace_back(keys_arr); + value_arrs.emplace_back(values_arr); + } + auto keys_chunked_arr = std::make_shared<::arrow::ChunkedArray>(key_arrs); + auto values_chunked_arr = std::make_shared<::arrow::ChunkedArray>(value_arrs); + return ::arrow::Table::Make(::arrow::schema({::arrow::field("keys", std::make_shared()), + ::arrow::field("values", ::arrow::utf8())}), + {keys_chunked_arr, values_chunked_arr}); +} + +template +void TestMergeOnPrimitiveType() { + auto table = MakeTable(); + + HashMerger merger(table, "keys"); + CHECK(merger.Init().ok()); + + auto pk_arr = + lance::arrow::ToArray::CType>({10, 20, 0, 5, 120, 32, 88}) + .ValueOrDie(); + + ::arrow::StringBuilder values_builder; + typename ::arrow::TypeTraits::BuilderType key_builder; + + CHECK(values_builder.AppendValues({"10", "20", "0", "5"}).ok()); + CHECK(values_builder.AppendNull().ok()); + CHECK(values_builder.Append("32").ok()); + CHECK(values_builder.AppendNull().ok()); + auto values_arr = values_builder.Finish().ValueOrDie(); + + auto result_batch = merger.Collect(pk_arr).ValueOrDie(); + auto expected = + ::arrow::RecordBatch::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}), + values_arr->length(), + {values_arr}); + CHECK(result_batch->Equals(*expected)); +} + +TEST_CASE("Hash merge with primitive keys") { + TestMergeOnPrimitiveType<::arrow::UInt8Type>(); + TestMergeOnPrimitiveType<::arrow::Int8Type>(); + TestMergeOnPrimitiveType<::arrow::UInt16Type>(); + TestMergeOnPrimitiveType<::arrow::Int16Type>(); + TestMergeOnPrimitiveType<::arrow::Int32Type>(); + TestMergeOnPrimitiveType<::arrow::UInt32Type>(); + TestMergeOnPrimitiveType<::arrow::UInt64Type>(); + TestMergeOnPrimitiveType<::arrow::Int64Type>(); +} + +template +void TestMergeOnFloatType() { + auto table = + lance::testing::TableFromJSON(::arrow::schema({::arrow::field("a", std::make_shared())}), + R"([{"a": 1.0}, {"a": 2.0}])") + .ValueOrDie(); + HashMerger merger(table, "a"); + CHECK(!merger.Init().ok()); +} + +TEST_CASE("Float keys are not supported") { + TestMergeOnFloatType<::arrow::FloatType>(); + TestMergeOnFloatType<::arrow::DoubleType>(); +}; + +TEST_CASE("Hash merge with string keys") { + auto keys = lance::arrow::ToArray({"a", "b", "c", "d"}).ValueOrDie(); + auto values = 
lance::arrow::ToArray({1, 2, 3, 4}).ValueOrDie(); + auto schema = ::arrow::schema( + {::arrow::field("keys", ::arrow::utf8()), ::arrow::field("values", ::arrow::int32())}); + auto table = ::arrow::Table::Make(schema, {keys, values}); + + HashMerger merger(table, "keys"); + CHECK(merger.Init().ok()); + + auto pk_arr = lance::arrow::ToArray({"c", "d", "e", "f", "a"}).ValueOrDie(); + auto batch = merger.Collect(pk_arr).ValueOrDie(); + CHECK(batch->num_columns() == 1); + ::arrow::Int32Builder builder; + CHECK(builder.AppendValues({3, 4}).ok()); + CHECK(builder.AppendNulls(2).ok()); + CHECK(builder.Append(1).ok()); + auto expected_values = builder.Finish().ValueOrDie(); + auto expected = ::arrow::RecordBatch::Make( + ::arrow::schema({::arrow::field("values", ::arrow::int32())}), 5, {expected_values}); + CHECK(batch->Equals(*expected)); +} \ No newline at end of file diff --git a/cpp/src/lance/arrow/type.h b/cpp/src/lance/arrow/type.h index 918f452480..37bdb0f964 100644 --- a/cpp/src/lance/arrow/type.h +++ b/cpp/src/lance/arrow/type.h @@ -26,6 +26,9 @@ #include #include +template +concept ArrowType = std::is_base_of<::arrow::DataType, T>::value; + template concept HasToString = requires(T t) { { t.ToString() } -> std::same_as; diff --git a/cpp/src/lance/arrow/updater.cc b/cpp/src/lance/arrow/updater.cc index 8830041a66..4f46256cf1 100644 --- a/cpp/src/lance/arrow/updater.cc +++ b/cpp/src/lance/arrow/updater.cc @@ -61,11 +61,13 @@ class Updater::Impl { ::arrow::Result> Next(); - ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::Array>& arr); + ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch); ::arrow::Result> Finish(); private: + friend class Updater; + auto data_dir() const { return dataset_->impl_->data_dir(); } const auto& fs() const { return dataset_->impl_->fs; } @@ -149,22 +151,21 @@ ::arrow::Result> Updater::Impl::Next() { return last_batch_; } -::arrow::Status Updater::Impl::UpdateBatch(const std::shared_ptr<::arrow::Array>& arr) { +::arrow::Status Updater::Impl::UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch) { // Sanity checks. 
if (!last_batch_) { return ::arrow::Status::IOError( "Did not read batch before update, did you call Updater::Next() before?"); } - if (last_batch_->num_rows() != arr->length()) { + if (last_batch_->num_rows() != batch->num_rows()) { return ::arrow::Status::IOError( fmt::format("Updater::Update: input size({}) != output size({})", last_batch_->num_rows(), - arr->length())); + batch->num_rows())); } assert(writer_); last_batch_.reset(); - auto batch = ::arrow::RecordBatch::Make(column_schema_->ToArrow(), arr->length(), {arr}); return writer_->Write(batch); } @@ -184,9 +185,8 @@ Updater::~Updater() {} ::arrow::Result> Updater::Make( std::shared_ptr dataset, - const std::shared_ptr<::arrow::Field>& field, + const std::shared_ptr<::arrow::Schema>& arrow_schema, const std::vector& projection_columns) { - auto arrow_schema = ::arrow::schema({field}); ARROW_ASSIGN_OR_RAISE(auto full_schema, dataset->impl_->manifest->schema()->Merge(*arrow_schema)); ARROW_ASSIGN_OR_RAISE(auto column_schema, full_schema->Project(*arrow_schema)); ARROW_ASSIGN_OR_RAISE(auto fragment_iter, dataset->GetFragments()); @@ -204,7 +204,12 @@ ::arrow::Result> Updater::Make( ::arrow::Result> Updater::Next() { return impl_->Next(); } ::arrow::Status Updater::UpdateBatch(const std::shared_ptr<::arrow::Array>& arr) { - return impl_->UpdateBatch(arr); + auto batch = ::arrow::RecordBatch::Make(impl_->column_schema_->ToArrow(), arr->length(), {arr}); + return UpdateBatch(batch); +} + +::arrow::Status Updater::UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch) { + return impl_->UpdateBatch(batch); } Updater::Updater(std::unique_ptr impl) : impl_(std::move(impl)) {} @@ -212,15 +217,15 @@ Updater::Updater(std::unique_ptr impl) : impl_(std::move(impl)) {} ::arrow::Result> Updater::Finish() { return impl_->Finish(); } UpdaterBuilder::UpdaterBuilder(std::shared_ptr source, - std::shared_ptr<::arrow::Field> field) - : dataset_(std::move(source)), field_(std::move(field)) {} + std::shared_ptr<::arrow::Schema> schema) + : dataset_(std::move(source)), schema_(std::move(schema)) {} void UpdaterBuilder::Project(std::vector columns) { projection_columns_ = std::move(columns); } ::arrow::Result> UpdaterBuilder::Finish() { - return Updater::Make(dataset_, field_, projection_columns_); + return Updater::Make(dataset_, schema_, projection_columns_); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/encodings/encoder.h b/cpp/src/lance/encodings/encoder.h index d12da012e4..91e744912b 100644 --- a/cpp/src/lance/encodings/encoder.h +++ b/cpp/src/lance/encodings/encoder.h @@ -33,9 +33,6 @@ class OutputStream; namespace lance::encodings { -template -concept ArrowType = std::is_base_of<::arrow::DataType, T>::value; - /// Encoding type Enum enum Encoding { NONE = 0,
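
Reviewer note: a minimal usage sketch of the new `LanceDataset::Merge()` API, pieced together from `dataset.h` and `dataset_test.cc` in this patch. The dataset URI, column names, and values are placeholders and the exact include set is an assumption; error handling mirrors the `ARROW_*` macros used elsewhere in the patch. This is illustrative only and not part of the diff.

```cpp
#include <arrow/api.h>
#include <arrow/filesystem/api.h>

#include <memory>
#include <string>

#include "lance/arrow/dataset.h"

// Sketch: attach a side table of attributes to an existing dataset, keyed by "id".
// The URI and column names below are hypothetical.
::arrow::Result<std::shared_ptr<lance::arrow::LanceDataset>> MergeAttributes(
    const std::string& base_uri /* placeholder path */) {
  auto fs = std::make_shared<::arrow::fs::LocalFileSystem>();
  ARROW_ASSIGN_OR_RAISE(auto dataset, lance::arrow::LanceDataset::Make(fs, base_uri));

  // Build the right-hand table: one row per key; keys may be unordered or partial.
  ::arrow::Int32Builder id_builder;
  ::arrow::DoubleBuilder attr_builder;
  ARROW_RETURN_NOT_OK(id_builder.AppendValues({5, 1, 10, 3, 8}));
  ARROW_RETURN_NOT_OK(attr_builder.AppendValues({5.0, 1.0, 10.0, 3.0, 8.0}));
  ARROW_ASSIGN_OR_RAISE(auto ids, id_builder.Finish());
  ARROW_ASSIGN_OR_RAISE(auto attrs, attr_builder.Finish());
  auto table = ::arrow::Table::Make(
      ::arrow::schema({::arrow::field("id", ::arrow::int32()),
                       ::arrow::field("attrs", ::arrow::float64())}),
      {ids, attrs});

  // Left-join semantics: dataset rows without a matching "id" get NULL "attrs"
  // (or the type's default value for plain encodings, per the TODO in dataset_test.cc).
  // Returns a new dataset version; the previous version stays readable.
  return dataset->Merge(table, "id");
}
```
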
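A second sketch, showing the generalized updater path (`NewUpdate(schema)` → `UpdaterBuilder` → `Updater::Next()` / `UpdateBatch(RecordBatch)`), which is the loop `Merge()` drives internally in `dataset.cc`. The column names, the assumed `int32` "id" type, and the derived-value computation are hypothetical; only the call sequence is taken from this patch.

```cpp
#include <arrow/api.h>

#include <memory>
#include <string>

#include "lance/arrow/dataset.h"
#include "lance/arrow/updater.h"

// Sketch: backfill two computed columns ("x2" and "label", hypothetical names) in one pass.
::arrow::Result<std::shared_ptr<lance::arrow::LanceDataset>> AddDerivedColumns(
    const std::shared_ptr<lance::arrow::LanceDataset>& dataset) {
  auto new_columns = ::arrow::schema({::arrow::field("x2", ::arrow::int64()),
                                      ::arrow::field("label", ::arrow::utf8())});
  ARROW_ASSIGN_OR_RAISE(auto builder, dataset->NewUpdate(new_columns));
  builder->Project({"id"});  // only read the "id" column from the source dataset
  ARROW_ASSIGN_OR_RAISE(auto updater, builder->Finish());

  while (true) {
    ARROW_ASSIGN_OR_RAISE(auto batch, updater->Next());
    if (!batch) break;  // all fragments visited

    // Assumes the source "id" column is int32.
    auto ids = std::static_pointer_cast<::arrow::Int32Array>(batch->GetColumnByName("id"));
    ::arrow::Int64Builder x2_builder;
    ::arrow::StringBuilder label_builder;
    for (int64_t i = 0; i < ids->length(); ++i) {
      ARROW_RETURN_NOT_OK(x2_builder.Append(static_cast<int64_t>(ids->Value(i)) * 2));
      ARROW_RETURN_NOT_OK(label_builder.Append(std::to_string(ids->Value(i))));
    }
    ARROW_ASSIGN_OR_RAISE(auto x2_arr, x2_builder.Finish());
    ARROW_ASSIGN_OR_RAISE(auto label_arr, label_builder.Finish());

    // The replacement batch must match `new_columns` and the row count of `batch`.
    auto new_batch =
        ::arrow::RecordBatch::Make(new_columns, batch->num_rows(), {x2_arr, label_arr});
    ARROW_RETURN_NOT_OK(updater->UpdateBatch(new_batch));
  }
  return updater->Finish();
}
```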