From e8922f913d730c4ef7fc9ac84df64fe897471fdf Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Thu, 27 Oct 2022 16:34:55 -0700 Subject: [PATCH] Merge two Schemas (#263) --- cpp/src/lance/format/schema.cc | 89 +++++++++++++++++++++++++++-- cpp/src/lance/format/schema.h | 12 ++++ cpp/src/lance/format/schema_test.cc | 42 ++++++++++++++ 3 files changed, 139 insertions(+), 4 deletions(-) diff --git a/cpp/src/lance/format/schema.cc b/cpp/src/lance/format/schema.cc index 4034b5f530..6213033b60 100644 --- a/cpp/src/lance/format/schema.cc +++ b/cpp/src/lance/format/schema.cc @@ -42,7 +42,7 @@ namespace lance::format { Field::Field() : id_(-1), parent_(-1) {} Field::Field(const std::shared_ptr<::arrow::Field>& field) - : id_(0), + : id_(-1), parent_(-1), name_(field->name()), logical_type_(arrow::ToLogicalType(field->type()).ValueOrDie()), @@ -334,8 +334,10 @@ int32_t Field::id() const { return id_; } void Field::SetId(int32_t parent_id, int32_t* current_id) { parent_ = parent_id; - id_ = (*current_id); - *current_id += 1; + if (id_ < 0) { + id_ = (*current_id); + *current_id += 1; + } for (auto& child : children_) { child->SetId(id_, current_id); } @@ -383,6 +385,45 @@ std::shared_ptr Field::Project(const std::shared_ptr<::arrow::Field>& arr return new_field; } +::arrow::Result> Field::Merge(const ::arrow::Field& arrow_field) const { + if (name() != arrow_field.name()) { + return ::arrow::Status::Invalid( + "Attempt to merge two different fields: ", name(), "!=", arrow_field.name()); + } + auto self_type = type(); + if (self_type->id() != arrow_field.type()->id()) { + return ::arrow::Status::Invalid("Can not merge two fields with different types: ", + self_type->ToString(), + " != ", + arrow_field.type()->ToString()); + }; + auto new_field = Copy(true); + if (::arrow::is_list_like(self_type->id())) { + auto list_type = std::dynamic_pointer_cast<::arrow::ListType>(arrow_field.type()); + + auto item_field = field(0); + ARROW_ASSIGN_OR_RAISE(auto new_item_field, item_field->Merge(*list_type->value_field())); + new_field->children_[0] = new_item_field; + } else if (lance::arrow::is_struct(self_type)) { + auto struct_type = std::dynamic_pointer_cast<::arrow::StructType>(arrow_field.type()); + for (auto& arrow_child : struct_type->fields()) { + bool found = false; + for (std::size_t i = 0; i < new_field->children_.size(); ++i) { + if (new_field->children_[i]->name_ == arrow_child->name()) { + ARROW_ASSIGN_OR_RAISE(new_field->children_[i], + new_field->children_[i]->Merge(*arrow_child)); + found = true; + break; + } + } + if (!found) { + new_field->children_.emplace_back(std::make_shared(arrow_child)); + } + } + } + return new_field; +} + bool Field::Equals(const Field& other, bool check_id) const { if (check_id && (id_ != other.id_ || parent_ != other.parent_)) { return false; @@ -570,6 +611,27 @@ ::arrow::Result> Schema::Exclude(const Schema& other) co return excluded; } +::arrow::Result> Schema::Merge(const ::arrow::Schema& arrow_schema) const { + auto merged = std::make_shared(); + for (auto& field : fields_) { + auto arrow_field = arrow_schema.GetFieldByName(field->name()); + if (arrow_field) { + ARROW_ASSIGN_OR_RAISE(auto new_field, field->Merge(*arrow_field)); + merged->AddField(new_field); + } else { + merged->AddField(field); + } + } + for (auto& arrow_field : arrow_schema.fields()) { + if (!GetField(arrow_field->name())) { + merged->AddField(std::make_shared(arrow_field)); + } + } + // Assign to new IDs + merged->AssignIds(); + return merged; +} + void Schema::AddField(std::shared_ptr f) { fields_.emplace_back(f); } std::shared_ptr Schema::GetField(int32_t id) const { @@ -629,12 +691,31 @@ std::shared_ptr Schema::Copy() const { } void Schema::AssignIds() { - int cur_id = 0; + int cur_id = GetMaxId() + 1; for (auto& field : fields_) { field->SetId(-1, &cur_id); } } +int32_t Schema::GetMaxId() const { + class MaxIdVisitor : public FieldVisitor { + public: + ::arrow::Status Visit(std::shared_ptr field) override { + max_id_ = std::max(field->id(), max_id_); + for (auto& child : field->children_) { + ARROW_RETURN_NOT_OK(Visit(child)); + } + return ::arrow::Status::OK(); + } + int32_t max_id_ = -1; + }; + auto visitor = MaxIdVisitor(); + if (!visitor.VisitSchema(*this).ok()) { + fmt::print(stderr, "Error when collecting max ID"); + } + return visitor.max_id_; +} + bool Schema::RemoveField(int32_t id) { for (auto it = fields_.begin(); it != fields_.end(); ++it) { if ((*it)->id() == id) { diff --git a/cpp/src/lance/format/schema.h b/cpp/src/lance/format/schema.h index af10c68057..c857a1ef7b 100644 --- a/cpp/src/lance/format/schema.h +++ b/cpp/src/lance/format/schema.h @@ -78,6 +78,12 @@ class Schema final { /// \return The newly created schema, excluding any column in "other". ::arrow::Result> Exclude(const Schema& other) const; + /// Merge with new fields. + /// + /// \param arrow_schema the schema to be merged. + /// \return A newly merged schema. + ::arrow::Result> Merge(const ::arrow::Schema& arrow_schema) const; + /// Add a new parent field. void AddField(std::shared_ptr f); @@ -114,6 +120,9 @@ class Schema final { /// (Re-)Assign Field IDs to all the fields. void AssignIds(); + /// Get the max assigned ID. + int32_t GetMaxId() const; + /// Make a full copy of the schema. std::shared_ptr Copy() const; @@ -216,6 +225,9 @@ class Field final { /// Project an arrow field to this field. std::shared_ptr Project(const std::shared_ptr<::arrow::Field>& arrow_field) const; + /// Merge an arrow field with this field. + ::arrow::Result> Merge(const ::arrow::Field& arrow_field) const; + /// Load dictionary array from disk. ::arrow::Status LoadDictionary(std::shared_ptr<::arrow::io::RandomAccessFile> infile); diff --git a/cpp/src/lance/format/schema_test.cc b/cpp/src/lance/format/schema_test.cc index 3eaaeb15d2..7a800af5a1 100644 --- a/cpp/src/lance/format/schema_test.cc +++ b/cpp/src/lance/format/schema_test.cc @@ -26,6 +26,7 @@ #include "lance/testing/json.h" using lance::arrow::ToArray; +using lance::format::Schema; using lance::testing::MakeDataset; using lance::testing::TableFromJSON; @@ -217,4 +218,45 @@ TEST_CASE("Test schema metadata") { CHECK(dataset->schema()->metadata()); CHECK(dataset->schema()->metadata()->Get("k1").ValueOrDie() == "v1"); CHECK(dataset->schema()->metadata()->Get("k1").ValueOrDie() == "v1"); +} + +TEST_CASE("Test merge two schemas") { + auto base_schema = Schema(::arrow::schema( + {::arrow::field("a", ::arrow::int32()), ::arrow::field("b", ::arrow::utf8())})); + auto merged = + base_schema.Merge(*::arrow::schema({::arrow::field("c", ::arrow::list(::arrow::utf8()))})) + .ValueOrDie(); + CHECK(merged->GetField("a")->id() == 0); + CHECK(merged->GetField("b")->id() == 1); + CHECK(merged->GetField("c")->id() == 2); +} + +TEST_CASE("Test merge two structs") { + auto base_schema = Schema(::arrow::schema( + {::arrow::field("a", ::arrow::struct_({::arrow::field("b", ::arrow::int32())}))})); + auto schema_c = + ::arrow::schema({::arrow::field("a", + ::arrow::struct_({::arrow::field("b", ::arrow::int32()), + ::arrow::field("c", ::arrow::int64())}))}); + auto merged = base_schema.Merge(*schema_c).ValueOrDie(); + auto a = merged->GetField("a"); + CHECK(a->id() == 0); + auto expected_struct = ::arrow::struct_( + {::arrow::field("b", ::arrow::int32()), ::arrow::field("c", ::arrow::int64())}); + INFO("Expected type: " << expected_struct->ToString() << " Got: " << a->type()->ToString()); + CHECK(a->type()->Equals(*expected_struct)); + CHECK(merged->GetField("a.b")->id() == 1); + CHECK(merged->GetField("a.c")->id() == 2); +} + +TEST_CASE("Test merge two list of structs") { + auto base_schema = Schema(::arrow::schema({::arrow::field( + "a", ::arrow::list(::arrow::struct_({::arrow::field("b1", ::arrow::int32())})))})); + auto addon_schema = ::arrow::schema({::arrow::field( + "a", ::arrow::list(::arrow::struct_({::arrow::field("b2", ::arrow::utf8())})))}); + auto merged = base_schema.Merge(*addon_schema).ValueOrDie(); + CHECK(merged->ToArrow()->Equals(::arrow::schema( + {::arrow::field("a", + ::arrow::list(::arrow::struct_({::arrow::field("b1", ::arrow::int32()), + ::arrow::field("b2", ::arrow::utf8())})))}))); } \ No newline at end of file