From 5219708f4e2af96fc6812180217afe3e334efde9 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 12:15:53 -0800 Subject: [PATCH 01/20] sanity checks --- cpp/include/lance/arrow/dataset.h | 14 ++++++++++++++ cpp/src/lance/arrow/dataset.cc | 28 ++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 1639ae1661..abc53ff856 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -130,6 +130,20 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> NewUpdate( const std::shared_ptr<::arrow::Field>& new_field) const; + /// Add all columns, except the "on" table, from an in-memory table. + /// + /// The algorithm follows the semantic of LEFT JOIN. The difference to LEFT JOIN + /// is that this function does not allow one row on the left ("this" dataset) + /// maps to two distinct rows on the right ("other"). + /// However, if a matched row on the right side does not exist, it allows to fill NULL. + /// + /// \param other the table to merge with this dataset. + /// \param on the column to be compared to. + /// This column must exist in both side and have the same data type.. + /// \return `::arrow::Status::OK` if success. + ::arrow::Result> AddColumns(const ::arrow::Table& other, + const std::string& on); + ::arrow::Result> ReplaceSchema( std::shared_ptr<::arrow::Schema> schema) const override; diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 5723159882..5c26f0eb41 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -382,4 +382,32 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm return ::arrow::MakeVectorIterator(fragments); } +::arrow::Result> LanceDataset::AddColumns(const ::arrow::Table& other, + const std::string& on) { + auto left_column = schema_->GetFieldByName(on); + if (left_column == nullptr) { + return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the dataset.", on)); + } + auto right_column = other.GetColumnByName(on); + if (right_column == nullptr) { + return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the table.", on)); + } + auto& left_type = left_column->type(); + auto& right_type = right_column->type(); + if (!::arrow::is_primitive(right_type->id()) && !::arrow::is_string(right_type->id())) { + return ::arrow::Status::Invalid("Only support primitive or string column type, got: ", + right_type->ToString()); + } + if (!left_type->Equals(right_type)) { + return ::arrow::Status::Invalid("LanceDataset::AddColumns: types are not equal: ", + left_type->ToString(), + " != ", + right_type->ToString()); + } + + // First step, hashing + + return ::arrow::Result>(); +} + } // namespace lance::arrow \ No newline at end of file From 5e8ca969626708b0fa2f535efae18ce1f475fcd0 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 14:23:18 -0800 Subject: [PATCH 02/20] build phase 1 hash map --- cpp/include/lance/arrow/dataset.h | 3 ++ cpp/include/lance/arrow/updater.h | 8 +-- cpp/src/lance/arrow/dataset.cc | 81 ++++++++++++++++++++++++++++--- cpp/src/lance/arrow/type.h | 3 ++ cpp/src/lance/arrow/updater.cc | 9 ++-- cpp/src/lance/encodings/encoder.h | 3 -- 6 files changed, 88 insertions(+), 19 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index abc53ff856..35e055acae 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -130,6 +130,9 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> NewUpdate( const std::shared_ptr<::arrow::Field>& new_field) const; + ::arrow::Result> NewUpdate( + const std::shared_ptr<::arrow::Schema>& new_columns) const; + /// Add all columns, except the "on" table, from an in-memory table. /// /// The algorithm follows the semantic of LEFT JOIN. The difference to LEFT JOIN diff --git a/cpp/include/lance/arrow/updater.h b/cpp/include/lance/arrow/updater.h index 0c9df6c518..453d6518dd 100644 --- a/cpp/include/lance/arrow/updater.h +++ b/cpp/include/lance/arrow/updater.h @@ -74,13 +74,13 @@ class Updater { /// Make a new Updater /// /// \param dataset The dataset to be updated. - /// \param field the (new) column to update. + /// \param schema the (new) columns to update. /// \param projection_columns the columns to read from source dataset. /// /// \return an Updater if success. static ::arrow::Result> Make( std::shared_ptr dataset, - const std::shared_ptr<::arrow::Field>& field, + const std::shared_ptr<::arrow::Schema>& schema, const std::vector& projection_columns); /// PIMPL @@ -99,7 +99,7 @@ class Updater { /// parameters to build a Updater. class UpdaterBuilder { public: - UpdaterBuilder(std::shared_ptr dataset, std::shared_ptr<::arrow::Field> field); + UpdaterBuilder(std::shared_ptr dataset, std::shared_ptr<::arrow::Schema> schema); /// Set the projection columns from the source dataset. void Project(std::vector columns); @@ -109,7 +109,7 @@ class UpdaterBuilder { private: std::shared_ptr dataset_; - std::shared_ptr<::arrow::Field> field_; + std::shared_ptr<::arrow::Schema> schema_; std::vector projection_columns_; }; diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 5c26f0eb41..01aa82b580 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -18,12 +18,16 @@ #include #include #include +#include #include #include #include #include #include +#include +#include +#include #include #include "lance/arrow/dataset_ext.h" @@ -313,8 +317,12 @@ DatasetVersion LanceDataset::version() const { return impl_->manifest->GetDatase ::arrow::Result> LanceDataset::NewUpdate( const std::shared_ptr<::arrow::Field>& new_field) const { - return std::make_shared(std::make_shared(*this), - std::move(new_field)); + return NewUpdate(::arrow::schema({new_field})); +} + +::arrow::Result> LanceDataset::NewUpdate( + const std::shared_ptr<::arrow::Schema>& new_columns) const { + return std::make_shared(std::make_shared(*this), new_columns); } ::arrow::Result> LanceDataset::AddColumn( @@ -382,8 +390,30 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm return ::arrow::MakeVectorIterator(fragments); } +/// Build index map: key => {chunk_id, idx_in_chunk}. +/// +template ::CType> +::arrow::Result>> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) { + std::unordered_map> key_to_chunk_index; + for (int64_t chk = 0; chk < chunked_arr->num_chunks(); chk++) { + auto arr = std::dynamic_pointer_cast::ArrayType>( + chunked_arr->chunk(chk)); + for (int64_t idx = 0; idx < arr->length(); idx++) { + auto value = arr->Value(idx); + auto key = std::hash{}(value); + auto ret = key_to_chunk_index.emplace(key, std::make_tuple(chk, idx)); + if (!ret.second) { + return ::arrow::Status::IndexError("Duplicated key found: ", value); + } + } + } + return std::move(key_to_chunk_index); +} + ::arrow::Result> LanceDataset::AddColumns(const ::arrow::Table& other, const std::string& on) { + /// Sanity checks auto left_column = schema_->GetFieldByName(on); if (left_column == nullptr) { return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the dataset.", on)); @@ -392,12 +422,9 @@ ::arrow::Result> LanceDataset::AddColumns(const :: if (right_column == nullptr) { return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the table.", on)); } + auto& left_type = left_column->type(); auto& right_type = right_column->type(); - if (!::arrow::is_primitive(right_type->id()) && !::arrow::is_string(right_type->id())) { - return ::arrow::Status::Invalid("Only support primitive or string column type, got: ", - right_type->ToString()); - } if (!left_type->Equals(right_type)) { return ::arrow::Status::Invalid("LanceDataset::AddColumns: types are not equal: ", left_type->ToString(), @@ -405,7 +432,47 @@ ::arrow::Result> LanceDataset::AddColumns(const :: right_type->ToString()); } - // First step, hashing + // First phase, build hash table (in memory for simplicity) + ::arrow::Result>> map_build_result; + +#define BUILD_CHUNK_IDX(TypeId) \ + case TypeId: \ + map_build_result = \ + BuildHashChunkIndex::Type>(right_column); \ + break; + + switch (right_type->id()) { + BUILD_CHUNK_IDX(::arrow::Type::UINT8); + BUILD_CHUNK_IDX(::arrow::Type::INT8); + BUILD_CHUNK_IDX(::arrow::Type::UINT16); + BUILD_CHUNK_IDX(::arrow::Type::INT16); + BUILD_CHUNK_IDX(::arrow::Type::UINT32); + BUILD_CHUNK_IDX(::arrow::Type::INT32); + BUILD_CHUNK_IDX(::arrow::Type::UINT64); + BUILD_CHUNK_IDX(::arrow::Type::INT64); + // BUILD_CHUNK_IDX(::arrow::Type::HALF_FLOAT); + BUILD_CHUNK_IDX(::arrow::Type::FLOAT); + BUILD_CHUNK_IDX(::arrow::Type::DOUBLE); + case ::arrow::Type::STRING: + map_build_result = BuildHashChunkIndex<::arrow::StringType, std::string_view>(right_column); + break; + default: + return ::arrow::Status::Invalid("Only support primitive or string type, got: ", + right_type->ToString()); + } + +#undef BUILD_CHUNK_IDX + + if (!map_build_result.ok()) { + return map_build_result.status(); + } + auto hash_map = map_build_result.ValueUnsafe(); + + // Second phase + auto table_schema = other.schema(); + ARROW_ASSIGN_OR_RAISE(auto merged_schema, + table_schema->RemoveField(table_schema->GetFieldIndex(on))); + ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(merged_schema))); return ::arrow::Result>(); } diff --git a/cpp/src/lance/arrow/type.h b/cpp/src/lance/arrow/type.h index 918f452480..37bdb0f964 100644 --- a/cpp/src/lance/arrow/type.h +++ b/cpp/src/lance/arrow/type.h @@ -26,6 +26,9 @@ #include #include +template +concept ArrowType = std::is_base_of<::arrow::DataType, T>::value; + template concept HasToString = requires(T t) { { t.ToString() } -> std::same_as; diff --git a/cpp/src/lance/arrow/updater.cc b/cpp/src/lance/arrow/updater.cc index 8830041a66..00291483f7 100644 --- a/cpp/src/lance/arrow/updater.cc +++ b/cpp/src/lance/arrow/updater.cc @@ -184,9 +184,8 @@ Updater::~Updater() {} ::arrow::Result> Updater::Make( std::shared_ptr dataset, - const std::shared_ptr<::arrow::Field>& field, + const std::shared_ptr<::arrow::Schema>& arrow_schema, const std::vector& projection_columns) { - auto arrow_schema = ::arrow::schema({field}); ARROW_ASSIGN_OR_RAISE(auto full_schema, dataset->impl_->manifest->schema()->Merge(*arrow_schema)); ARROW_ASSIGN_OR_RAISE(auto column_schema, full_schema->Project(*arrow_schema)); ARROW_ASSIGN_OR_RAISE(auto fragment_iter, dataset->GetFragments()); @@ -212,15 +211,15 @@ Updater::Updater(std::unique_ptr impl) : impl_(std::move(impl)) {} ::arrow::Result> Updater::Finish() { return impl_->Finish(); } UpdaterBuilder::UpdaterBuilder(std::shared_ptr source, - std::shared_ptr<::arrow::Field> field) - : dataset_(std::move(source)), field_(std::move(field)) {} + std::shared_ptr<::arrow::Schema> schema) + : dataset_(std::move(source)), schema_(std::move(schema)) {} void UpdaterBuilder::Project(std::vector columns) { projection_columns_ = std::move(columns); } ::arrow::Result> UpdaterBuilder::Finish() { - return Updater::Make(dataset_, field_, projection_columns_); + return Updater::Make(dataset_, schema_, projection_columns_); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/encodings/encoder.h b/cpp/src/lance/encodings/encoder.h index d12da012e4..91e744912b 100644 --- a/cpp/src/lance/encodings/encoder.h +++ b/cpp/src/lance/encodings/encoder.h @@ -33,9 +33,6 @@ class OutputStream; namespace lance::encodings { -template -concept ArrowType = std::is_base_of<::arrow::DataType, T>::value; - /// Encoding type Enum enum Encoding { NONE = 0, From 12c5cbf73744356df34eae54f75b9937cb9fb71f Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 15:06:56 -0800 Subject: [PATCH 03/20] move to hash merger --- cpp/src/lance/arrow/CMakeLists.txt | 2 + cpp/src/lance/arrow/dataset.cc | 48 +++++----------- cpp/src/lance/arrow/hash_merger.cc | 89 ++++++++++++++++++++++++++++++ cpp/src/lance/arrow/hash_merger.h | 48 ++++++++++++++++ 4 files changed, 153 insertions(+), 34 deletions(-) create mode 100644 cpp/src/lance/arrow/hash_merger.cc create mode 100644 cpp/src/lance/arrow/hash_merger.h diff --git a/cpp/src/lance/arrow/CMakeLists.txt b/cpp/src/lance/arrow/CMakeLists.txt index 8b26e54bc3..ed04d84a06 100644 --- a/cpp/src/lance/arrow/CMakeLists.txt +++ b/cpp/src/lance/arrow/CMakeLists.txt @@ -24,6 +24,8 @@ add_library( file_lance_ext.h fragment.cc fragment.h + hash_merger.cc + hash_merger.h scanner.cc stl.h type.cc diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 01aa82b580..42cf1ca5f1 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -33,6 +33,7 @@ #include "lance/arrow/dataset_ext.h" #include "lance/arrow/file_lance.h" #include "lance/arrow/fragment.h" +#include "lance/arrow/hash_merger.h" #include "lance/arrow/updater.h" #include "lance/arrow/utils.h" #include "lance/format/manifest.h" @@ -433,47 +434,26 @@ ::arrow::Result> LanceDataset::AddColumns(const :: } // First phase, build hash table (in memory for simplicity) - ::arrow::Result>> map_build_result; - -#define BUILD_CHUNK_IDX(TypeId) \ - case TypeId: \ - map_build_result = \ - BuildHashChunkIndex::Type>(right_column); \ - break; - - switch (right_type->id()) { - BUILD_CHUNK_IDX(::arrow::Type::UINT8); - BUILD_CHUNK_IDX(::arrow::Type::INT8); - BUILD_CHUNK_IDX(::arrow::Type::UINT16); - BUILD_CHUNK_IDX(::arrow::Type::INT16); - BUILD_CHUNK_IDX(::arrow::Type::UINT32); - BUILD_CHUNK_IDX(::arrow::Type::INT32); - BUILD_CHUNK_IDX(::arrow::Type::UINT64); - BUILD_CHUNK_IDX(::arrow::Type::INT64); - // BUILD_CHUNK_IDX(::arrow::Type::HALF_FLOAT); - BUILD_CHUNK_IDX(::arrow::Type::FLOAT); - BUILD_CHUNK_IDX(::arrow::Type::DOUBLE); - case ::arrow::Type::STRING: - map_build_result = BuildHashChunkIndex<::arrow::StringType, std::string_view>(right_column); - break; - default: - return ::arrow::Status::Invalid("Only support primitive or string type, got: ", - right_type->ToString()); - } - -#undef BUILD_CHUNK_IDX - - if (!map_build_result.ok()) { - return map_build_result.status(); - } - auto hash_map = map_build_result.ValueUnsafe(); + auto merger = HashMerger(); + ARROW_RETURN_NOT_OK(merger.Build(other, on)); // Second phase auto table_schema = other.schema(); ARROW_ASSIGN_OR_RAISE(auto merged_schema, table_schema->RemoveField(table_schema->GetFieldIndex(on))); ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(merged_schema))); + update_builder->Project({on}); + ARROW_ASSIGN_OR_RAISE(auto updater, update_builder->Finish()); + while (true) { + ARROW_ASSIGN_OR_RAISE(auto batch, updater->Next()); + if (!batch) { + break; + } + assert(batch->schema()->Equals(::arrow::schema({left_column}))); + auto index_arr = batch->GetColumnByName(on); + ARROW_ASSIGN_OR_RAISE(auto right_batch, merger.Collect(index_arr)); + } return ::arrow::Result>(); } diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc new file mode 100644 index 0000000000..f27cc997a5 --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -0,0 +1,89 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lance/arrow/hash_merger.h" + +#include + +namespace lance::arrow { + +/// Build index map: key => {chunk_id, idx_in_chunk}. +/// +template ::CType> +::arrow::Result>> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) { + std::unordered_map> key_to_chunk_index; + for (int64_t chk = 0; chk < chunked_arr->num_chunks(); chk++) { + auto arr = std::dynamic_pointer_cast::ArrayType>( + chunked_arr->chunk(chk)); + for (int64_t idx = 0; idx < arr->length(); idx++) { + auto value = arr->Value(idx); + auto key = std::hash{}(value); + auto ret = key_to_chunk_index.emplace(key, std::make_tuple(chk, idx)); + if (!ret.second) { + return ::arrow::Status::IndexError("Duplicated key found: ", value); + } + } + } + return std::move(key_to_chunk_index); +} + +::arrow::Status HashMerger::Build(const ::arrow::Table& table, const std::string& col_name) { + auto chunked_arr = table.GetColumnByName(col_name); + index_column_type_ = chunked_arr->type(); + + ::arrow::Result>> result; + +#define BUILD_CHUNK_IDX(TypeId) \ + case TypeId: \ + result = BuildHashChunkIndex::Type>(chunked_arr); \ + break; + + switch (index_column_type_->id()) { + BUILD_CHUNK_IDX(::arrow::Type::UINT8); + BUILD_CHUNK_IDX(::arrow::Type::INT8); + BUILD_CHUNK_IDX(::arrow::Type::UINT16); + BUILD_CHUNK_IDX(::arrow::Type::INT16); + BUILD_CHUNK_IDX(::arrow::Type::UINT32); + BUILD_CHUNK_IDX(::arrow::Type::INT32); + BUILD_CHUNK_IDX(::arrow::Type::UINT64); + BUILD_CHUNK_IDX(::arrow::Type::INT64); + // BUILD_CHUNK_IDX(::arrow::Type::HALF_FLOAT); + BUILD_CHUNK_IDX(::arrow::Type::FLOAT); + BUILD_CHUNK_IDX(::arrow::Type::DOUBLE); + case ::arrow::Type::STRING: + result = BuildHashChunkIndex<::arrow::StringType, std::string_view>(chunked_arr); + break; + default: + return ::arrow::Status::Invalid("Only support primitive or string type, got: ", + index_column_type_->ToString()); + } + + if (!result.ok()) { + return result.status(); + } + index_map_ = std::move(result.ValueOrDie()); + return ::arrow::Status::OK(); +} + +::arrow::Result> HashMerger::Collect( + const std::shared_ptr<::arrow::Array>& on_col) { + if (!on_col->type()->Equals(index_column_type_)) { + return ::arrow::Status::TypeError( + "Index column match mismatch: ", on_col->type()->ToString(), " != ", index_column_type_); + } + return ::arrow::Result>(); +} + +} // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h new file mode 100644 index 0000000000..84ab123954 --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger.h @@ -0,0 +1,48 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "lance/arrow/type.h" + +namespace lance::arrow { + +class HashMerger { + public: + HashMerger() = default; + + /// Build a hash map on column specified by "col_name". + ::arrow::Status Build(const ::arrow::Table& table, const std::string& col_name); + + ::arrow::Result> Collect( + const std::shared_ptr<::arrow::Array>& on_col); + + private: + std::unordered_map> index_map_; + std::shared_ptr<::arrow::DataType> index_column_type_; +}; + +} // namespace lance::arrow From f26c6637d60c224ccf7351807762eee0ce1cae4f Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 15:36:48 -0800 Subject: [PATCH 04/20] change updater to update records --- cpp/include/lance/arrow/dataset.h | 6 ++++-- cpp/include/lance/arrow/updater.h | 2 ++ cpp/src/lance/arrow/dataset.cc | 8 +++++--- cpp/src/lance/arrow/hash_merger.cc | 21 ++++++++++++++++++--- cpp/src/lance/arrow/hash_merger.h | 14 +++++++++++--- cpp/src/lance/arrow/updater.cc | 18 ++++++++++++------ 6 files changed, 52 insertions(+), 17 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 35e055acae..5e8ac2d2af 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -144,8 +144,10 @@ class LanceDataset : public ::arrow::dataset::Dataset { /// \param on the column to be compared to. /// This column must exist in both side and have the same data type.. /// \return `::arrow::Status::OK` if success. - ::arrow::Result> AddColumns(const ::arrow::Table& other, - const std::string& on); + ::arrow::Result> AddColumns( + const ::arrow::Table& other, + const std::string& on, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); ::arrow::Result> ReplaceSchema( std::shared_ptr<::arrow::Schema> schema) const override; diff --git a/cpp/include/lance/arrow/updater.h b/cpp/include/lance/arrow/updater.h index 453d6518dd..bc3ccbc562 100644 --- a/cpp/include/lance/arrow/updater.h +++ b/cpp/include/lance/arrow/updater.h @@ -67,6 +67,8 @@ class Updater { /// The array must has the same length as the batch returned previously via `Next()`. ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::Array>& arr); + ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch); + /// Finish the update and returns a new version of dataset. ::arrow::Result> Finish(); diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 42cf1ca5f1..aec8c81225 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -413,7 +413,8 @@ ::arrow::Result>> B } ::arrow::Result> LanceDataset::AddColumns(const ::arrow::Table& other, - const std::string& on) { + const std::string& on, + ::arrow::MemoryPool* pool) { /// Sanity checks auto left_column = schema_->GetFieldByName(on); if (left_column == nullptr) { @@ -434,8 +435,8 @@ ::arrow::Result> LanceDataset::AddColumns(const :: } // First phase, build hash table (in memory for simplicity) - auto merger = HashMerger(); - ARROW_RETURN_NOT_OK(merger.Build(other, on)); + auto merger = HashMerger(other, on, pool); + ARROW_RETURN_NOT_OK(merger.Build()); // Second phase auto table_schema = other.schema(); @@ -453,6 +454,7 @@ ::arrow::Result> LanceDataset::AddColumns(const :: assert(batch->schema()->Equals(::arrow::schema({left_column}))); auto index_arr = batch->GetColumnByName(on); ARROW_ASSIGN_OR_RAISE(auto right_batch, merger.Collect(index_arr)); + ARROW_RETURN_NOT_OK(updater->UpdateBatch(right_batch)); } return ::arrow::Result>(); } diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index f27cc997a5..614d5c4495 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -18,6 +18,13 @@ namespace lance::arrow { +HashMerger::HashMerger(const ::arrow::Table& table, + std::string index_column, + ::arrow::MemoryPool* pool) + : table_(table), column_name_(std::move(index_column)), pool_(pool) {} + +namespace { + /// Build index map: key => {chunk_id, idx_in_chunk}. /// template ::CType> @@ -39,8 +46,10 @@ ::arrow::Result>> B return std::move(key_to_chunk_index); } -::arrow::Status HashMerger::Build(const ::arrow::Table& table, const std::string& col_name) { - auto chunked_arr = table.GetColumnByName(col_name); +} // namespace + +::arrow::Status HashMerger::Build() { + auto chunked_arr = table_.GetColumnByName(column_name_); index_column_type_ = chunked_arr->type(); ::arrow::Result>> result; @@ -83,7 +92,13 @@ ::arrow::Result> HashMerger::Collect( return ::arrow::Status::TypeError( "Index column match mismatch: ", on_col->type()->ToString(), " != ", index_column_type_); } - return ::arrow::Result>(); + for (int i = 0; i < table_.num_columns(); i++) { + auto field = table_.field(i); + if (field->name() == column_name_) { + continue; + } + } + return ::arrow::Status::NotImplemented("not impl"); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h index 84ab123954..212a91e59b 100644 --- a/cpp/src/lance/arrow/hash_merger.h +++ b/cpp/src/lance/arrow/hash_merger.h @@ -30,19 +30,27 @@ namespace lance::arrow { +/// A basic implementation of in-memory hash (join) merge. class HashMerger { public: - HashMerger() = default; + HashMerger() = delete; - /// Build a hash map on column specified by "col_name". - ::arrow::Status Build(const ::arrow::Table& table, const std::string& col_name); + explicit HashMerger(const ::arrow::Table& table, + std::string index_column, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + /// Build a hash map on column specified by "column". + ::arrow::Status Build(); ::arrow::Result> Collect( const std::shared_ptr<::arrow::Array>& on_col); private: + const ::arrow::Table& table_; + std::string column_name_; std::unordered_map> index_map_; std::shared_ptr<::arrow::DataType> index_column_type_; + ::arrow::MemoryPool* pool_; }; } // namespace lance::arrow diff --git a/cpp/src/lance/arrow/updater.cc b/cpp/src/lance/arrow/updater.cc index 00291483f7..4f46256cf1 100644 --- a/cpp/src/lance/arrow/updater.cc +++ b/cpp/src/lance/arrow/updater.cc @@ -61,11 +61,13 @@ class Updater::Impl { ::arrow::Result> Next(); - ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::Array>& arr); + ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch); ::arrow::Result> Finish(); private: + friend class Updater; + auto data_dir() const { return dataset_->impl_->data_dir(); } const auto& fs() const { return dataset_->impl_->fs; } @@ -149,22 +151,21 @@ ::arrow::Result> Updater::Impl::Next() { return last_batch_; } -::arrow::Status Updater::Impl::UpdateBatch(const std::shared_ptr<::arrow::Array>& arr) { +::arrow::Status Updater::Impl::UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch) { // Sanity checks. if (!last_batch_) { return ::arrow::Status::IOError( "Did not read batch before update, did you call Updater::Next() before?"); } - if (last_batch_->num_rows() != arr->length()) { + if (last_batch_->num_rows() != batch->num_rows()) { return ::arrow::Status::IOError( fmt::format("Updater::Update: input size({}) != output size({})", last_batch_->num_rows(), - arr->length())); + batch->num_rows())); } assert(writer_); last_batch_.reset(); - auto batch = ::arrow::RecordBatch::Make(column_schema_->ToArrow(), arr->length(), {arr}); return writer_->Write(batch); } @@ -203,7 +204,12 @@ ::arrow::Result> Updater::Make( ::arrow::Result> Updater::Next() { return impl_->Next(); } ::arrow::Status Updater::UpdateBatch(const std::shared_ptr<::arrow::Array>& arr) { - return impl_->UpdateBatch(arr); + auto batch = ::arrow::RecordBatch::Make(impl_->column_schema_->ToArrow(), arr->length(), {arr}); + return UpdateBatch(batch); +} + +::arrow::Status Updater::UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch) { + return impl_->UpdateBatch(batch); } Updater::Updater(std::unique_ptr impl) : impl_(std::move(impl)) {} From 25c801641a6787aa60112487311c258df83445c9 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 16:06:36 -0800 Subject: [PATCH 05/20] add test --- cpp/src/lance/arrow/CMakeLists.txt | 1 + cpp/src/lance/arrow/hash_merger.cc | 1 + cpp/src/lance/arrow/hash_merger_test.cc | 68 +++++++++++++++++++++++++ 3 files changed, 70 insertions(+) create mode 100644 cpp/src/lance/arrow/hash_merger_test.cc diff --git a/cpp/src/lance/arrow/CMakeLists.txt b/cpp/src/lance/arrow/CMakeLists.txt index ed04d84a06..dac655a17e 100644 --- a/cpp/src/lance/arrow/CMakeLists.txt +++ b/cpp/src/lance/arrow/CMakeLists.txt @@ -41,6 +41,7 @@ add_lance_test(api_test) add_lance_test(arrow_dataset_test) add_lance_test(dataset_test) add_lance_test(fragment_test) +add_lance_test(hash_merger_test) add_lance_test(scanner_test) add_lance_test(type_test) add_lance_test(updater_test) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index 614d5c4495..22efb9b4b8 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -98,6 +98,7 @@ ::arrow::Result> HashMerger::Collect( continue; } } + fmt::print("{}", fmt::ptr(pool_)); return ::arrow::Status::NotImplemented("not impl"); } diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc new file mode 100644 index 0000000000..a285500e2a --- /dev/null +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -0,0 +1,68 @@ +// Copyright 2022 Lance Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lance/arrow/hash_merger.h" + +#include +#include +#include +#include + +#include +#include +#include + +#include "lance/arrow/type.h" + +using lance::arrow::HashMerger; + +template +std::shared_ptr<::arrow::Table> MakeTable() { + std::vector::CType> keys; + typename ::arrow::TypeTraits::BuilderType keys_builder; + typename ::arrow::StringBuilder value_builder; + ::arrow::ArrayVector key_arrs, value_arrs; + for (int chunk = 0; chunk < 5; chunk++) { + for (int i = 0; i < 10; i++) { + typename ::arrow::TypeTraits::CType value = chunk * 10 + i; + CHECK(keys_builder.Append(value).ok()); + CHECK(value_builder.Append(fmt::format("{}", value)).ok()); + } + auto keys_arr = keys_builder.Finish().ValueOrDie(); + auto values_arr = value_builder.Finish().ValueOrDie(); + key_arrs.emplace_back(keys_arr); + value_arrs.emplace_back(values_arr); + } + auto keys_chunked_arr = std::make_shared<::arrow::ChunkedArray>(key_arrs); + auto values_chunked_arr = std::make_shared<::arrow::ChunkedArray>(value_arrs); + return ::arrow::Table::Make(::arrow::schema({::arrow::field("keys", std::make_shared()), + ::arrow::field("values", ::arrow::utf8())}), + {keys_chunked_arr, values_chunked_arr}); +} + +template +void TestBuildHashMap() { + auto table = MakeTable(); + + HashMerger merger(*table, "keys"); + CHECK(merger.Build().ok()); +} + +TEST_CASE("Build Hash") { + TestBuildHashMap<::arrow::UInt8Type>(); + TestBuildHashMap<::arrow::Int32Type>(); + TestBuildHashMap<::arrow::UInt64Type>(); + TestBuildHashMap<::arrow::FloatType>(); + TestBuildHashMap<::arrow::DoubleType>(); +} \ No newline at end of file From 3c97c439fb7d780b74cacd3b6ca374893e471ded Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:13:29 -0800 Subject: [PATCH 06/20] results --- cpp/include/lance/arrow/dataset.h | 2 +- cpp/src/lance/arrow/dataset.cc | 13 +- cpp/src/lance/arrow/hash_merger.cc | 161 ++++++++++++++++-------- cpp/src/lance/arrow/hash_merger.h | 18 ++- cpp/src/lance/arrow/hash_merger_test.cc | 17 ++- 5 files changed, 142 insertions(+), 69 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 5e8ac2d2af..272b946268 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -145,7 +145,7 @@ class LanceDataset : public ::arrow::dataset::Dataset { /// This column must exist in both side and have the same data type.. /// \return `::arrow::Status::OK` if success. ::arrow::Result> AddColumns( - const ::arrow::Table& other, + const std::shared_ptr<::arrow::Table>& other, const std::string& on, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index aec8c81225..c9314adcf0 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -412,15 +412,16 @@ ::arrow::Result>> B return std::move(key_to_chunk_index); } -::arrow::Result> LanceDataset::AddColumns(const ::arrow::Table& other, - const std::string& on, - ::arrow::MemoryPool* pool) { +::arrow::Result> LanceDataset::AddColumns( + const std::shared_ptr<::arrow::Table>& other, + const std::string& on, + ::arrow::MemoryPool* pool) { /// Sanity checks auto left_column = schema_->GetFieldByName(on); if (left_column == nullptr) { return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the dataset.", on)); } - auto right_column = other.GetColumnByName(on); + auto right_column = other->GetColumnByName(on); if (right_column == nullptr) { return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the table.", on)); } @@ -436,10 +437,10 @@ ::arrow::Result> LanceDataset::AddColumns(const :: // First phase, build hash table (in memory for simplicity) auto merger = HashMerger(other, on, pool); - ARROW_RETURN_NOT_OK(merger.Build()); + ARROW_RETURN_NOT_OK(merger.Init()); // Second phase - auto table_schema = other.schema(); + auto table_schema = other->schema(); ARROW_ASSIGN_OR_RAISE(auto merged_schema, table_schema->RemoveField(table_schema->GetFieldIndex(on))); ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(merged_schema))); diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index 22efb9b4b8..e9bb5a92c2 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -14,75 +14,115 @@ #include "lance/arrow/hash_merger.h" +#include #include +#include + +#include "lance/arrow/stl.h" +#include "lance/arrow/type.h" + namespace lance::arrow { -HashMerger::HashMerger(const ::arrow::Table& table, - std::string index_column, - ::arrow::MemoryPool* pool) - : table_(table), column_name_(std::move(index_column)), pool_(pool) {} - -namespace { - -/// Build index map: key => {chunk_id, idx_in_chunk}. -/// -template ::CType> -::arrow::Result>> BuildHashChunkIndex( - const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) { - std::unordered_map> key_to_chunk_index; - for (int64_t chk = 0; chk < chunked_arr->num_chunks(); chk++) { - auto arr = std::dynamic_pointer_cast::ArrayType>( - chunked_arr->chunk(chk)); - for (int64_t idx = 0; idx < arr->length(); idx++) { - auto value = arr->Value(idx); - auto key = std::hash{}(value); - auto ret = key_to_chunk_index.emplace(key, std::make_tuple(chk, idx)); - if (!ret.second) { - return ::arrow::Status::IndexError("Duplicated key found: ", value); +class HashMerger::Impl { + public: + virtual ~Impl() = default; + + virtual void ComputeHash(const std::shared_ptr<::arrow::Array>& arr, + std::vector>* out) = 0; + + virtual ::arrow::Result> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) = 0; +}; + +template ::CType> +class TypedHashMerger : public HashMerger::Impl { + public: + void ComputeHash(const std::shared_ptr<::arrow::Array>& arr, + std::vector>* out) override { + auto hash_func = std::hash{}; + assert(out); + auto values = std::dynamic_pointer_cast::ArrayType>(arr); + assert(values); + out->reserve(values->length()); + out->clear(); + for (int i = 0; i < values->length(); ++i) { + if (values->IsNull(i)) { + out->emplace_back(std::nullopt); + } else { + auto value = values->Value(i); + out->emplace_back(hash_func(value)); } } } - return std::move(key_to_chunk_index); -} -} // namespace + ::arrow::Result> BuildHashChunkIndex( + const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) override { + std::unordered_map key_to_chunk_index; + int64_t index = 0; + std::vector> hashes; + for (const auto& chunk : chunked_arr->chunks()) { + ComputeHash(chunk, &hashes); + assert(chunk->length() == static_cast(hashes.size())); + for (std::size_t i = 0; i < hashes.size(); i++) { + const auto& key = hashes[i]; + if (key.has_value()) { + auto ret = key_to_chunk_index.emplace(key.value(), index); + if (!ret.second) { + auto values = + std::dynamic_pointer_cast::ArrayType>(chunk); + return ::arrow::Status::IndexError("Duplicate key found: ", values->Value(i)); + } + } + index++; + } + } + return std::move(key_to_chunk_index); + } +}; + +HashMerger::HashMerger(std::shared_ptr<::arrow::Table> table, + std::string index_column, + ::arrow::MemoryPool* pool) + : table_(std::move(table)), column_name_(std::move(index_column)), pool_(pool) {} -::arrow::Status HashMerger::Build() { - auto chunked_arr = table_.GetColumnByName(column_name_); - index_column_type_ = chunked_arr->type(); +HashMerger::~HashMerger() {} - ::arrow::Result>> result; +::arrow::Status HashMerger::Init() { + auto chunked_arr = table_->GetColumnByName(column_name_); + if (chunked_arr == nullptr) { + return ::arrow::Status::Invalid("index column ", column_name_, " does not exist"); + } + index_column_type_ = chunked_arr->type(); -#define BUILD_CHUNK_IDX(TypeId) \ - case TypeId: \ - result = BuildHashChunkIndex::Type>(chunked_arr); \ +#define BUILD_IMPL(TypeId) \ + case TypeId: \ + impl_ = std::unique_ptr( \ + new TypedHashMerger::Type>()); \ break; switch (index_column_type_->id()) { - BUILD_CHUNK_IDX(::arrow::Type::UINT8); - BUILD_CHUNK_IDX(::arrow::Type::INT8); - BUILD_CHUNK_IDX(::arrow::Type::UINT16); - BUILD_CHUNK_IDX(::arrow::Type::INT16); - BUILD_CHUNK_IDX(::arrow::Type::UINT32); - BUILD_CHUNK_IDX(::arrow::Type::INT32); - BUILD_CHUNK_IDX(::arrow::Type::UINT64); - BUILD_CHUNK_IDX(::arrow::Type::INT64); - // BUILD_CHUNK_IDX(::arrow::Type::HALF_FLOAT); - BUILD_CHUNK_IDX(::arrow::Type::FLOAT); - BUILD_CHUNK_IDX(::arrow::Type::DOUBLE); + BUILD_IMPL(::arrow::Type::UINT8); + BUILD_IMPL(::arrow::Type::INT8); + BUILD_IMPL(::arrow::Type::UINT16); + BUILD_IMPL(::arrow::Type::INT16); + BUILD_IMPL(::arrow::Type::UINT32); + BUILD_IMPL(::arrow::Type::INT32); + BUILD_IMPL(::arrow::Type::UINT64); + BUILD_IMPL(::arrow::Type::INT64); + // BUILD_IMPL(::arrow::Type::HALF_FLOAT); + BUILD_IMPL(::arrow::Type::FLOAT); + BUILD_IMPL(::arrow::Type::DOUBLE); case ::arrow::Type::STRING: - result = BuildHashChunkIndex<::arrow::StringType, std::string_view>(chunked_arr); + impl_ = std::unique_ptr(new TypedHashMerger<::arrow::StringType, std::string_view>()); break; default: return ::arrow::Status::Invalid("Only support primitive or string type, got: ", index_column_type_->ToString()); } - if (!result.ok()) { - return result.status(); - } - index_map_ = std::move(result.ValueOrDie()); + ARROW_ASSIGN_OR_RAISE(index_map_, impl_->BuildHashChunkIndex(chunked_arr)); + return ::arrow::Status::OK(); } @@ -92,14 +132,27 @@ ::arrow::Result> HashMerger::Collect( return ::arrow::Status::TypeError( "Index column match mismatch: ", on_col->type()->ToString(), " != ", index_column_type_); } - for (int i = 0; i < table_.num_columns(); i++) { - auto field = table_.field(i); - if (field->name() == column_name_) { - continue; + std::vector> hashes; + impl_->ComputeHash(on_col, &hashes); + std::vector indices; + std::vector nulls; + for (const auto& hvalue : hashes) { + if (hvalue.has_value()) { + auto it = index_map_.find(hvalue.value()); + if (it != index_map_.end()) { + indices.emplace_back(it->second); + nulls.emplace_back(false); + } else { + nulls.emplace_back(true); + } + } else { + nulls.emplace_back(true); } } - fmt::print("{}", fmt::ptr(pool_)); - return ::arrow::Status::NotImplemented("not impl"); + ARROW_ASSIGN_OR_RAISE(auto indices_arr, lance::arrow::ToArray(indices)); + ARROW_ASSIGN_OR_RAISE(auto datum, ::arrow::compute::Take(table_, indices_arr)); + assert(datum.table()); + return datum.table()->CombineChunksToBatch(pool_); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h index 212a91e59b..27f9cb7413 100644 --- a/cpp/src/lance/arrow/hash_merger.h +++ b/cpp/src/lance/arrow/hash_merger.h @@ -35,22 +35,32 @@ class HashMerger { public: HashMerger() = delete; - explicit HashMerger(const ::arrow::Table& table, + /// HashMerger constructor. + explicit HashMerger(std::shared_ptr<::arrow::Table> table, std::string index_column, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + ~HashMerger(); + /// Build a hash map on column specified by "column". - ::arrow::Status Build(); + ::arrow::Status Init(); ::arrow::Result> Collect( const std::shared_ptr<::arrow::Array>& on_col); private: - const ::arrow::Table& table_; + std::shared_ptr<::arrow::Table> table_; std::string column_name_; - std::unordered_map> index_map_; + + class Impl; + std::unique_ptr impl_; + /// A map from `std::hash(key)` to the index (`int64_t`) in the table. + std::unordered_map index_map_; std::shared_ptr<::arrow::DataType> index_column_type_; ::arrow::MemoryPool* pool_; + + template + friend class TypedHashMerger; }; } // namespace lance::arrow diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index a285500e2a..5eb425c5c6 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -23,6 +23,7 @@ #include #include +#include "lance/arrow/stl.h" #include "lance/arrow/type.h" using lance::arrow::HashMerger; @@ -33,6 +34,7 @@ std::shared_ptr<::arrow::Table> MakeTable() { typename ::arrow::TypeTraits::BuilderType keys_builder; typename ::arrow::StringBuilder value_builder; ::arrow::ArrayVector key_arrs, value_arrs; + /// for (int chunk = 0; chunk < 5; chunk++) { for (int i = 0; i < 10; i++) { typename ::arrow::TypeTraits::CType value = chunk * 10 + i; @@ -55,14 +57,21 @@ template void TestBuildHashMap() { auto table = MakeTable(); - HashMerger merger(*table, "keys"); - CHECK(merger.Build().ok()); + HashMerger merger(table, "keys"); + CHECK(merger.Init().ok()); + + auto pk_arr = + lance::arrow::ToArray::CType>({0, 3, 5, 10, 20}).ValueOrDie(); + auto result_batch = merger.Collect(pk_arr).ValueOrDie(); + fmt::print("Result: {}\n", result_batch->ToString()); } -TEST_CASE("Build Hash") { +TEST_CASE("Hash merge with primitive keys") { TestBuildHashMap<::arrow::UInt8Type>(); TestBuildHashMap<::arrow::Int32Type>(); TestBuildHashMap<::arrow::UInt64Type>(); TestBuildHashMap<::arrow::FloatType>(); TestBuildHashMap<::arrow::DoubleType>(); -} \ No newline at end of file +} + +TEST_CASE("Hash merge with string keys") {} \ No newline at end of file From 8068ac30838579b4746881ee4ff5742597eeaa60 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:15:35 -0800 Subject: [PATCH 07/20] results --- cpp/src/lance/arrow/hash_merger.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index e9bb5a92c2..73bac07352 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -136,17 +136,18 @@ ::arrow::Result> HashMerger::Collect( impl_->ComputeHash(on_col, &hashes); std::vector indices; std::vector nulls; + ::arrow::Int64Builder indices_builder; + ARROW_RETURN_NOT_OK(indices_builder.Reserve(on_col->length())); for (const auto& hvalue : hashes) { if (hvalue.has_value()) { auto it = index_map_.find(hvalue.value()); if (it != index_map_.end()) { - indices.emplace_back(it->second); - nulls.emplace_back(false); + ARROW_RETURN_NOT_OK(indices_builder.Append(it->second)); } else { - nulls.emplace_back(true); + ARROW_RETURN_NOT_OK(indices_builder.AppendNull()); } } else { - nulls.emplace_back(true); + ARROW_RETURN_NOT_OK(indices_builder.AppendNull()); } } ARROW_ASSIGN_OR_RAISE(auto indices_arr, lance::arrow::ToArray(indices)); From b5a4e2c9db5b2036912613311c3094ef832a0c12 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:16:11 -0800 Subject: [PATCH 08/20] results --- cpp/src/lance/arrow/hash_merger_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index 5eb425c5c6..825af4a22d 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -61,7 +61,8 @@ void TestBuildHashMap() { CHECK(merger.Init().ok()); auto pk_arr = - lance::arrow::ToArray::CType>({0, 3, 5, 10, 20}).ValueOrDie(); + lance::arrow::ToArray::CType>({10, 20, 0, 5, 200, 32, 88}) + .ValueOrDie(); auto result_batch = merger.Collect(pk_arr).ValueOrDie(); fmt::print("Result: {}\n", result_batch->ToString()); } From 77d611a73f7eb55e9d07aa3a1a3c43785a32651a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:30:16 -0800 Subject: [PATCH 09/20] add test for merge on primitive types --- cpp/src/lance/arrow/hash_merger.cc | 4 +-- cpp/src/lance/arrow/hash_merger_test.cc | 39 ++++++++++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index 73bac07352..7cce1474ec 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -134,8 +134,6 @@ ::arrow::Result> HashMerger::Collect( } std::vector> hashes; impl_->ComputeHash(on_col, &hashes); - std::vector indices; - std::vector nulls; ::arrow::Int64Builder indices_builder; ARROW_RETURN_NOT_OK(indices_builder.Reserve(on_col->length())); for (const auto& hvalue : hashes) { @@ -150,7 +148,7 @@ ::arrow::Result> HashMerger::Collect( ARROW_RETURN_NOT_OK(indices_builder.AppendNull()); } } - ARROW_ASSIGN_OR_RAISE(auto indices_arr, lance::arrow::ToArray(indices)); + ARROW_ASSIGN_OR_RAISE(auto indices_arr, indices_builder.Finish()); ARROW_ASSIGN_OR_RAISE(auto datum, ::arrow::compute::Take(table_, indices_arr)); assert(datum.table()); return datum.table()->CombineChunksToBatch(pool_); diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index 825af4a22d..28ed3f39ff 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -54,25 +54,48 @@ std::shared_ptr<::arrow::Table> MakeTable() { } template -void TestBuildHashMap() { +void TestMergeOnPrimitiveType() { auto table = MakeTable(); HashMerger merger(table, "keys"); CHECK(merger.Init().ok()); auto pk_arr = - lance::arrow::ToArray::CType>({10, 20, 0, 5, 200, 32, 88}) + lance::arrow::ToArray::CType>({10, 20, 0, 5, 120, 32, 88}) .ValueOrDie(); + + ::arrow::StringBuilder values_builder; + typename ::arrow::TypeTraits::BuilderType key_builder; + + CHECK(values_builder.AppendValues({"10", "20", "0", "5"}).ok()); + CHECK(values_builder.AppendNull().ok()); + CHECK(values_builder.Append("32").ok()); + CHECK(values_builder.AppendNull().ok()); + auto values_arr = values_builder.Finish().ValueOrDie(); + + CHECK(key_builder.AppendValues({10, 20, 0, 5}).ok()); + CHECK(key_builder.AppendNull().ok()); + CHECK(key_builder.Append(32).ok()); + CHECK(key_builder.AppendNull().ok()); + auto keys_arr = key_builder.Finish().ValueOrDie(); + auto result_batch = merger.Collect(pk_arr).ValueOrDie(); - fmt::print("Result: {}\n", result_batch->ToString()); + auto expected = + ::arrow::RecordBatch::Make(::arrow::schema({::arrow::field("keys", std::make_shared()), + ::arrow::field("values", ::arrow::utf8())}), + values_arr->length(), + {keys_arr, values_arr}); + CHECK(result_batch->Equals(*expected)); } TEST_CASE("Hash merge with primitive keys") { - TestBuildHashMap<::arrow::UInt8Type>(); - TestBuildHashMap<::arrow::Int32Type>(); - TestBuildHashMap<::arrow::UInt64Type>(); - TestBuildHashMap<::arrow::FloatType>(); - TestBuildHashMap<::arrow::DoubleType>(); + TestMergeOnPrimitiveType<::arrow::UInt8Type>(); + TestMergeOnPrimitiveType<::arrow::Int8Type>(); + TestMergeOnPrimitiveType<::arrow::UInt16Type>(); + TestMergeOnPrimitiveType<::arrow::Int32Type>(); + TestMergeOnPrimitiveType<::arrow::UInt64Type>(); + TestMergeOnPrimitiveType<::arrow::FloatType>(); + TestMergeOnPrimitiveType<::arrow::DoubleType>(); } TEST_CASE("Hash merge with string keys") {} \ No newline at end of file From 37d638964db94c2a30f0445275ebac046e053a33 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:36:27 -0800 Subject: [PATCH 10/20] better comments --- cpp/include/lance/arrow/dataset.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 272b946268..27e97e7741 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -133,12 +133,12 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> NewUpdate( const std::shared_ptr<::arrow::Schema>& new_columns) const; - /// Add all columns, except the "on" table, from an in-memory table. + /// Add all columns from the table, except the "on" column. /// - /// The algorithm follows the semantic of LEFT JOIN. The difference to LEFT JOIN - /// is that this function does not allow one row on the left ("this" dataset) - /// maps to two distinct rows on the right ("other"). - /// However, if a matched row on the right side does not exist, it allows to fill NULL. + /// The algorithm follows the semantic of `LEFT JOIN` in SQL. + /// The difference to LEFT JOIN is that this function does not allow one row + /// on the left ("this" dataset) maps to two distinct rows on the right ("other"). + /// However, if it can not find a matched row on the right side, a NULL value is provided. /// /// \param other the table to merge with this dataset. /// \param on the column to be compared to. From 24930ecaa945300c75fc59bb5516631436b2504a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:40:09 -0800 Subject: [PATCH 11/20] make gcc happy --- cpp/src/lance/arrow/hash_merger.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index 7cce1474ec..b7e671f9ba 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -77,7 +77,7 @@ class TypedHashMerger : public HashMerger::Impl { index++; } } - return std::move(key_to_chunk_index); + return key_to_chunk_index; } }; From 80ce0d9a8775fa3615063bc665f7048d1d63fa81 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:57:29 -0800 Subject: [PATCH 12/20] test --- cpp/src/lance/arrow/hash_merger.cc | 11 ++++++- cpp/src/lance/arrow/hash_merger_test.cc | 39 ++++++++++++++++++------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index b7e671f9ba..bf88d9faaf 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -151,7 +151,16 @@ ::arrow::Result> HashMerger::Collect( ARROW_ASSIGN_OR_RAISE(auto indices_arr, indices_builder.Finish()); ARROW_ASSIGN_OR_RAISE(auto datum, ::arrow::compute::Take(table_, indices_arr)); assert(datum.table()); - return datum.table()->CombineChunksToBatch(pool_); + auto table = datum.table(); + + // Drop the index column. + for (int i = 0; i < table->num_columns(); ++i) { + if (table->field(i)->name() == column_name_) { + ARROW_ASSIGN_OR_RAISE(table, table->RemoveColumn(i)); + break; + } + } + return table->CombineChunksToBatch(pool_); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index 28ed3f39ff..2fc6df1a46 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -34,7 +34,6 @@ std::shared_ptr<::arrow::Table> MakeTable() { typename ::arrow::TypeTraits::BuilderType keys_builder; typename ::arrow::StringBuilder value_builder; ::arrow::ArrayVector key_arrs, value_arrs; - /// for (int chunk = 0; chunk < 5; chunk++) { for (int i = 0; i < 10; i++) { typename ::arrow::TypeTraits::CType value = chunk * 10 + i; @@ -73,18 +72,11 @@ void TestMergeOnPrimitiveType() { CHECK(values_builder.AppendNull().ok()); auto values_arr = values_builder.Finish().ValueOrDie(); - CHECK(key_builder.AppendValues({10, 20, 0, 5}).ok()); - CHECK(key_builder.AppendNull().ok()); - CHECK(key_builder.Append(32).ok()); - CHECK(key_builder.AppendNull().ok()); - auto keys_arr = key_builder.Finish().ValueOrDie(); - auto result_batch = merger.Collect(pk_arr).ValueOrDie(); auto expected = - ::arrow::RecordBatch::Make(::arrow::schema({::arrow::field("keys", std::make_shared()), - ::arrow::field("values", ::arrow::utf8())}), + ::arrow::RecordBatch::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}), values_arr->length(), - {keys_arr, values_arr}); + {values_arr}); CHECK(result_batch->Equals(*expected)); } @@ -92,10 +84,35 @@ TEST_CASE("Hash merge with primitive keys") { TestMergeOnPrimitiveType<::arrow::UInt8Type>(); TestMergeOnPrimitiveType<::arrow::Int8Type>(); TestMergeOnPrimitiveType<::arrow::UInt16Type>(); + TestMergeOnPrimitiveType<::arrow::Int16Type>(); TestMergeOnPrimitiveType<::arrow::Int32Type>(); + TestMergeOnPrimitiveType<::arrow::UInt32Type>(); TestMergeOnPrimitiveType<::arrow::UInt64Type>(); + TestMergeOnPrimitiveType<::arrow::Int64Type>(); TestMergeOnPrimitiveType<::arrow::FloatType>(); TestMergeOnPrimitiveType<::arrow::DoubleType>(); } -TEST_CASE("Hash merge with string keys") {} \ No newline at end of file +TEST_CASE("Hash merge with string keys") { + auto keys = lance::arrow::ToArray({"a", "b", "c", "d"}).ValueOrDie(); + auto values = lance::arrow::ToArray({1, 2, 3, 4}).ValueOrDie(); + auto schema = ::arrow::schema( + {::arrow::field("keys", ::arrow::utf8()), ::arrow::field("values", ::arrow::int32())}); + auto table = ::arrow::Table::Make(schema, {keys, values}); + + HashMerger merger(table, "keys"); + CHECK(merger.Init().ok()); + + auto pk_arr = lance::arrow::ToArray({"c", "d", "e", "f", "a"}).ValueOrDie(); + auto batch = merger.Collect(pk_arr).ValueOrDie(); + fmt::print("Batch is: {}\n", batch->ToString()); + CHECK(batch->num_columns() == 1); + ::arrow::Int32Builder builder; + CHECK(builder.AppendValues({3, 4}).ok()); + CHECK(builder.AppendNulls(2).ok()); + CHECK(builder.Append(0).ok()); + auto expected_values = builder.Finish().ValueOrDie(); + auto expected = ::arrow::RecordBatch::Make( + ::arrow::schema({::arrow::field("values", ::arrow::int32())}), 5, {expected_values}); + CHECK(batch->Equals(*expected)); +} \ No newline at end of file From 4a579001f264ba5371f7784c25c702753b9f926c Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 21:58:59 -0800 Subject: [PATCH 13/20] clean up imports --- cpp/src/lance/arrow/hash_merger.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h index 27f9cb7413..7782cd6290 100644 --- a/cpp/src/lance/arrow/hash_merger.h +++ b/cpp/src/lance/arrow/hash_merger.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include "lance/arrow/type.h" From 3ec88a2ac530cc8c2822577c36595691495258ba Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 22:21:31 -0800 Subject: [PATCH 14/20] fix test --- cpp/src/lance/arrow/hash_merger_test.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index 2fc6df1a46..e9ee3a153f 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -105,12 +105,11 @@ TEST_CASE("Hash merge with string keys") { auto pk_arr = lance::arrow::ToArray({"c", "d", "e", "f", "a"}).ValueOrDie(); auto batch = merger.Collect(pk_arr).ValueOrDie(); - fmt::print("Batch is: {}\n", batch->ToString()); CHECK(batch->num_columns() == 1); ::arrow::Int32Builder builder; CHECK(builder.AppendValues({3, 4}).ok()); CHECK(builder.AppendNulls(2).ok()); - CHECK(builder.Append(0).ok()); + CHECK(builder.Append(1).ok()); auto expected_values = builder.Finish().ValueOrDie(); auto expected = ::arrow::RecordBatch::Make( ::arrow::schema({::arrow::field("values", ::arrow::int32())}), 5, {expected_values}); From 4063567125888ab288d95badbb1ff9415a47f9d3 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 29 Nov 2022 22:24:39 -0800 Subject: [PATCH 15/20] finish updater --- cpp/src/lance/arrow/dataset.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index c9314adcf0..fe535e43d0 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -457,7 +457,7 @@ ::arrow::Result> LanceDataset::AddColumns( ARROW_ASSIGN_OR_RAISE(auto right_batch, merger.Collect(index_arr)); ARROW_RETURN_NOT_OK(updater->UpdateBatch(right_batch)); } - return ::arrow::Result>(); + return updater->Finish(); } } // namespace lance::arrow \ No newline at end of file From c01b6523a366304ee42cf5b6a28184a269905a8d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Wed, 30 Nov 2022 08:49:46 -0800 Subject: [PATCH 16/20] add test --- cpp/src/lance/arrow/dataset_test.cc | 33 +++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/cpp/src/lance/arrow/dataset_test.cc b/cpp/src/lance/arrow/dataset_test.cc index eb8334287d..f6bdfadfc7 100644 --- a/cpp/src/lance/arrow/dataset_test.cc +++ b/cpp/src/lance/arrow/dataset_test.cc @@ -283,4 +283,37 @@ TEST_CASE("Dataset add column with a function call") { ::arrow::field("doubles", ::arrow::float64())}), {ids, doubles}); CHECK(table2->Equals(*expected_table)); +} + +TEST_CASE("Dataset add columns with a table") { + auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie(); + auto schema = ::arrow::schema( + {::arrow::field("id", ::arrow::int32()), ::arrow::field("value", ::arrow::utf8())}); + auto table = ::arrow::Table::Make(schema, {ids, values}); + auto base_uri = WriteTable(table); + + auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); + auto dataset = lance::arrow::LanceDataset::Make(fs, base_uri).ValueOrDie(); + CHECK(dataset->version().version() == 1); + + auto added_ids = ToArray({5, 4, 3, 10, 12, 1}).ValueOrDie(); + auto added_values = ToArray({50, 40, 30, 100, 120, 10}).ValueOrDie(); + auto added_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("new_value", ::arrow::int32())}), + {added_ids, added_values}); + auto new_dataset = dataset->AddColumns(added_table, "id").ValueOrDie(); + CHECK(new_dataset->version().version() == 2); + auto new_table = + new_dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); + + // TODO: Plain array does not support null yet, so arr[1] = 0 instead of Null. + auto new_values = ToArray({10, 0, 30, 40, 50}).ValueOrDie(); + auto expected_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("value", ::arrow::utf8()), + ::arrow::field("new_value", ::arrow::int32())}), + {ids, values, new_values}); + CHECK(new_table->Equals(*expected_table)); } \ No newline at end of file From f433d7c5b4e053f77aac202a9f129e23c392dd50 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Wed, 30 Nov 2022 09:03:10 -0800 Subject: [PATCH 17/20] better comments --- cpp/include/lance/arrow/dataset.h | 21 +++++++++++++++++++++ cpp/include/lance/arrow/updater.h | 2 ++ cpp/src/lance/arrow/dataset.cc | 21 --------------------- cpp/src/lance/arrow/hash_merger.h | 2 ++ 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 27e97e7741..161b54f562 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -140,10 +140,31 @@ class LanceDataset : public ::arrow::dataset::Dataset { /// on the left ("this" dataset) maps to two distinct rows on the right ("other"). /// However, if it can not find a matched row on the right side, a NULL value is provided. /// + /// For example, + /// + /// \code + /// dataset (left) = { + /// "id": [1, 2, 3, 4], + /// "vals": ["a", "b", "c", "d"], + /// } + /// table (right) = { + /// "id": [5, 1, 10, 3, 8], + /// "attrs": [5.0, 1.0, 10.0, 3.0, 8.0], + /// } + /// + /// dataset.AddColumn(table, on="id") => + /// { + /// "id": [1, 2, 3, 4], + /// "vals": ["a", "b", "c", "d"], + /// "attrs": [1.0, Null, 3.0, Null], + /// } + /// \endcode + /// /// \param other the table to merge with this dataset. /// \param on the column to be compared to. /// This column must exist in both side and have the same data type.. /// \return `::arrow::Status::OK` if success. + /// ::arrow::Result> AddColumns( const std::shared_ptr<::arrow::Table>& other, const std::string& on, diff --git a/cpp/include/lance/arrow/updater.h b/cpp/include/lance/arrow/updater.h index bc3ccbc562..18a6624ea2 100644 --- a/cpp/include/lance/arrow/updater.h +++ b/cpp/include/lance/arrow/updater.h @@ -67,6 +67,8 @@ class Updater { /// The array must has the same length as the batch returned previously via `Next()`. ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::Array>& arr); + /// Update the values to new values, presented in a `RecordBatch`. + /// The batch must has the same length as the batch returned previously via `Next()`. ::arrow::Status UpdateBatch(const std::shared_ptr<::arrow::RecordBatch>& batch); /// Finish the update and returns a new version of dataset. diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index fe535e43d0..5c3968bb00 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -391,27 +391,6 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm return ::arrow::MakeVectorIterator(fragments); } -/// Build index map: key => {chunk_id, idx_in_chunk}. -/// -template ::CType> -::arrow::Result>> BuildHashChunkIndex( - const std::shared_ptr<::arrow::ChunkedArray>& chunked_arr) { - std::unordered_map> key_to_chunk_index; - for (int64_t chk = 0; chk < chunked_arr->num_chunks(); chk++) { - auto arr = std::dynamic_pointer_cast::ArrayType>( - chunked_arr->chunk(chk)); - for (int64_t idx = 0; idx < arr->length(); idx++) { - auto value = arr->Value(idx); - auto key = std::hash{}(value); - auto ret = key_to_chunk_index.emplace(key, std::make_tuple(chk, idx)); - if (!ret.second) { - return ::arrow::Status::IndexError("Duplicated key found: ", value); - } - } - } - return std::move(key_to_chunk_index); -} - ::arrow::Result> LanceDataset::AddColumns( const std::shared_ptr<::arrow::Table>& other, const std::string& on, diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h index 7782cd6290..9fbb1e0629 100644 --- a/cpp/src/lance/arrow/hash_merger.h +++ b/cpp/src/lance/arrow/hash_merger.h @@ -30,6 +30,7 @@ namespace lance::arrow { /// A basic implementation of in-memory hash (join) merge. +/// class HashMerger { public: HashMerger() = delete; @@ -44,6 +45,7 @@ class HashMerger { /// Build a hash map on column specified by "column". ::arrow::Status Init(); + /// Collect the batch records with the same keys in the column. ::arrow::Result> Collect( const std::shared_ptr<::arrow::Array>& on_col); From db6d628cc011aae7275bba70d18571e44a10f14c Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 1 Dec 2022 13:10:40 -0800 Subject: [PATCH 18/20] do not join float --- cpp/src/lance/arrow/hash_merger.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index bf88d9faaf..a0200ffb66 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -110,9 +110,10 @@ ::arrow::Status HashMerger::Init() { BUILD_IMPL(::arrow::Type::INT32); BUILD_IMPL(::arrow::Type::UINT64); BUILD_IMPL(::arrow::Type::INT64); - // BUILD_IMPL(::arrow::Type::HALF_FLOAT); - BUILD_IMPL(::arrow::Type::FLOAT); - BUILD_IMPL(::arrow::Type::DOUBLE); + case ::arrow::Type::HALF_FLOAT: + case ::arrow::Type::FLOAT: + case ::arrow::Type::DOUBLE: + return ::arrow::Status::Invalid("Do not support merge on floating points"); case ::arrow::Type::STRING: impl_ = std::unique_ptr(new TypedHashMerger<::arrow::StringType, std::string_view>()); break; From 3042752493e1f56340364cd327e9f63e05667ef6 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 1 Dec 2022 13:56:50 -0800 Subject: [PATCH 19/20] test --- cpp/src/lance/arrow/hash_merger_test.cc | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/cpp/src/lance/arrow/hash_merger_test.cc b/cpp/src/lance/arrow/hash_merger_test.cc index e9ee3a153f..0d9619c3ff 100644 --- a/cpp/src/lance/arrow/hash_merger_test.cc +++ b/cpp/src/lance/arrow/hash_merger_test.cc @@ -25,6 +25,7 @@ #include "lance/arrow/stl.h" #include "lance/arrow/type.h" +#include "lance/testing/json.h" using lance::arrow::HashMerger; @@ -89,10 +90,23 @@ TEST_CASE("Hash merge with primitive keys") { TestMergeOnPrimitiveType<::arrow::UInt32Type>(); TestMergeOnPrimitiveType<::arrow::UInt64Type>(); TestMergeOnPrimitiveType<::arrow::Int64Type>(); - TestMergeOnPrimitiveType<::arrow::FloatType>(); - TestMergeOnPrimitiveType<::arrow::DoubleType>(); } +template +void TestMergeOnFloatType() { + auto table = + lance::testing::TableFromJSON(::arrow::schema({::arrow::field("a", std::make_shared())}), + R"([{"a": 1.0}, {"a": 2.0}])") + .ValueOrDie(); + HashMerger merger(table, "a"); + CHECK(!merger.Init().ok()); +} + +TEST_CASE("Float keys are not supported") { + TestMergeOnFloatType<::arrow::FloatType>(); + TestMergeOnFloatType<::arrow::DoubleType>(); +}; + TEST_CASE("Hash merge with string keys") { auto keys = lance::arrow::ToArray({"a", "b", "c", "d"}).ValueOrDie(); auto values = lance::arrow::ToArray({1, 2, 3, 4}).ValueOrDie(); From 5ef309a8e43f8efee169baa49a33f24c3ff5c59a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 1 Dec 2022 14:45:11 -0800 Subject: [PATCH 20/20] address comments --- cpp/include/lance/arrow/dataset.h | 23 ++++++++++++++----- cpp/src/lance/arrow/dataset.cc | 34 +++++++++++++++++++---------- cpp/src/lance/arrow/dataset_test.cc | 2 +- cpp/src/lance/arrow/hash_merger.cc | 10 ++++----- cpp/src/lance/arrow/hash_merger.h | 2 +- 5 files changed, 46 insertions(+), 25 deletions(-) diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index 161b54f562..72c7fbc873 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -133,7 +133,7 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> NewUpdate( const std::shared_ptr<::arrow::Schema>& new_columns) const; - /// Add all columns from the table, except the "on" column. + /// Merge an in-memory table, except the "right_on" column. /// /// The algorithm follows the semantic of `LEFT JOIN` in SQL. /// The difference to LEFT JOIN is that this function does not allow one row @@ -160,13 +160,24 @@ class LanceDataset : public ::arrow::dataset::Dataset { /// } /// \endcode /// - /// \param other the table to merge with this dataset. - /// \param on the column to be compared to. - /// This column must exist in both side and have the same data type.. + /// \param right the table to merge with this dataset. + /// \param left_on the column in this dataset be compared to. + /// \param right_on the column in the table to be compared to. + /// This column must exist in both side and have the same data type. + /// \param pool memory pool /// \return `::arrow::Status::OK` if success. /// - ::arrow::Result> AddColumns( - const std::shared_ptr<::arrow::Table>& other, + ::arrow::Result> Merge( + const std::shared_ptr<::arrow::Table>& right, + const std::string& left_on, + const std::string& right_on, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + + /// Merge an in-memory table, both sides must have the same column specified by the "on" name. + /// + /// See `Merge(right, left_on, right_on, pool)` for details. + ::arrow::Result> Merge( + const std::shared_ptr<::arrow::Table>& right, const std::string& on, ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 5c3968bb00..d65d29a211 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -391,18 +391,28 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm return ::arrow::MakeVectorIterator(fragments); } -::arrow::Result> LanceDataset::AddColumns( +::arrow::Result> LanceDataset::Merge( const std::shared_ptr<::arrow::Table>& other, const std::string& on, ::arrow::MemoryPool* pool) { + return Merge(other, on, on, pool); +} + +::arrow::Result> LanceDataset::Merge( + const std::shared_ptr<::arrow::Table>& right, + const std::string& left_on, + const std::string& right_on, + ::arrow::MemoryPool* pool) { /// Sanity checks - auto left_column = schema_->GetFieldByName(on); + auto left_column = schema_->GetFieldByName(left_on); if (left_column == nullptr) { - return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the dataset.", on)); + return ::arrow::Status::Invalid( + fmt::format("Column {} does not exist in the dataset.", left_on)); } - auto right_column = other->GetColumnByName(on); + auto right_column = right->GetColumnByName(right_on); if (right_column == nullptr) { - return ::arrow::Status::Invalid(fmt::format("Column {} does not exist in the table.", on)); + return ::arrow::Status::Invalid( + fmt::format("Column {} does not exist in the table.", right_on)); } auto& left_type = left_column->type(); @@ -415,15 +425,15 @@ ::arrow::Result> LanceDataset::AddColumns( } // First phase, build hash table (in memory for simplicity) - auto merger = HashMerger(other, on, pool); + auto merger = HashMerger(right, right_on, pool); ARROW_RETURN_NOT_OK(merger.Init()); // Second phase - auto table_schema = other->schema(); - ARROW_ASSIGN_OR_RAISE(auto merged_schema, - table_schema->RemoveField(table_schema->GetFieldIndex(on))); - ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(merged_schema))); - update_builder->Project({on}); + auto table_schema = right->schema(); + ARROW_ASSIGN_OR_RAISE(auto incoming_schema, + table_schema->RemoveField(table_schema->GetFieldIndex(right_on))); + ARROW_ASSIGN_OR_RAISE(auto update_builder, NewUpdate(std::move(incoming_schema))); + update_builder->Project({left_on}); ARROW_ASSIGN_OR_RAISE(auto updater, update_builder->Finish()); while (true) { @@ -432,7 +442,7 @@ ::arrow::Result> LanceDataset::AddColumns( break; } assert(batch->schema()->Equals(::arrow::schema({left_column}))); - auto index_arr = batch->GetColumnByName(on); + auto index_arr = batch->GetColumnByName(left_on); ARROW_ASSIGN_OR_RAISE(auto right_batch, merger.Collect(index_arr)); ARROW_RETURN_NOT_OK(updater->UpdateBatch(right_batch)); } diff --git a/cpp/src/lance/arrow/dataset_test.cc b/cpp/src/lance/arrow/dataset_test.cc index f6bdfadfc7..b99f465682 100644 --- a/cpp/src/lance/arrow/dataset_test.cc +++ b/cpp/src/lance/arrow/dataset_test.cc @@ -303,7 +303,7 @@ TEST_CASE("Dataset add columns with a table") { ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), ::arrow::field("new_value", ::arrow::int32())}), {added_ids, added_values}); - auto new_dataset = dataset->AddColumns(added_table, "id").ValueOrDie(); + auto new_dataset = dataset->Merge(added_table, "id").ValueOrDie(); CHECK(new_dataset->version().version() == 2); auto new_table = new_dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); diff --git a/cpp/src/lance/arrow/hash_merger.cc b/cpp/src/lance/arrow/hash_merger.cc index a0200ffb66..c47312afc3 100644 --- a/cpp/src/lance/arrow/hash_merger.cc +++ b/cpp/src/lance/arrow/hash_merger.cc @@ -128,15 +128,15 @@ ::arrow::Status HashMerger::Init() { } ::arrow::Result> HashMerger::Collect( - const std::shared_ptr<::arrow::Array>& on_col) { - if (!on_col->type()->Equals(index_column_type_)) { + const std::shared_ptr<::arrow::Array>& index_arr) { + if (!index_arr->type()->Equals(index_column_type_)) { return ::arrow::Status::TypeError( - "Index column match mismatch: ", on_col->type()->ToString(), " != ", index_column_type_); + "Index column match mismatch: ", index_arr->type()->ToString(), " != ", index_column_type_); } std::vector> hashes; - impl_->ComputeHash(on_col, &hashes); + impl_->ComputeHash(index_arr, &hashes); ::arrow::Int64Builder indices_builder; - ARROW_RETURN_NOT_OK(indices_builder.Reserve(on_col->length())); + ARROW_RETURN_NOT_OK(indices_builder.Reserve(index_arr->length())); for (const auto& hvalue : hashes) { if (hvalue.has_value()) { auto it = index_map_.find(hvalue.value()); diff --git a/cpp/src/lance/arrow/hash_merger.h b/cpp/src/lance/arrow/hash_merger.h index 9fbb1e0629..6c04d8b55a 100644 --- a/cpp/src/lance/arrow/hash_merger.h +++ b/cpp/src/lance/arrow/hash_merger.h @@ -47,7 +47,7 @@ class HashMerger { /// Collect the batch records with the same keys in the column. ::arrow::Result> Collect( - const std::shared_ptr<::arrow::Array>& on_col); + const std::shared_ptr<::arrow::Array>& index_arr); private: std::shared_ptr<::arrow::Table> table_;