Skip to content

Commit

Permalink
Merge two Schemas (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Oct 27, 2022
1 parent 107c6cb commit e8922f9
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 4 deletions.
89 changes: 85 additions & 4 deletions cpp/src/lance/format/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace lance::format {
Field::Field() : id_(-1), parent_(-1) {}

Field::Field(const std::shared_ptr<::arrow::Field>& field)
: id_(0),
: id_(-1),
parent_(-1),
name_(field->name()),
logical_type_(arrow::ToLogicalType(field->type()).ValueOrDie()),
Expand Down Expand Up @@ -334,8 +334,10 @@ int32_t Field::id() const { return id_; }

void Field::SetId(int32_t parent_id, int32_t* current_id) {
parent_ = parent_id;
id_ = (*current_id);
*current_id += 1;
if (id_ < 0) {
id_ = (*current_id);
*current_id += 1;
}
for (auto& child : children_) {
child->SetId(id_, current_id);
}
Expand Down Expand Up @@ -383,6 +385,45 @@ std::shared_ptr<Field> Field::Project(const std::shared_ptr<::arrow::Field>& arr
return new_field;
}

/// Merge an arrow field into this field and return the result as a new Field.
///
/// This field is left untouched; the merge is applied to a copy obtained via
/// Copy(true).
///
/// - List-like fields merge their item field recursively.
/// - Struct fields merge children that share a name recursively, and append
///   children that only exist in \p arrow_field.
/// - Any other type is returned as a plain copy of this field.
///
/// NOTE(review): only the arrow type *id* is compared, so parameterized types
/// with the same id (e.g. different timestamp units) are treated as compatible
/// -- confirm this is intended.
///
/// \param arrow_field the arrow field to merge into this field.
/// \return the merged field, or Status::Invalid if the names or type ids differ.
::arrow::Result<std::shared_ptr<Field>> Field::Merge(const ::arrow::Field& arrow_field) const {
  if (name() != arrow_field.name()) {
    return ::arrow::Status::Invalid(
        "Attempt to merge two different fields: ", name(), "!=", arrow_field.name());
  }
  auto self_type = type();
  if (self_type->id() != arrow_field.type()->id()) {
    return ::arrow::Status::Invalid("Can not merge two fields with different types: ",
                                    self_type->ToString(),
                                    " != ",
                                    arrow_field.type()->ToString());
  }
  auto new_field = Copy(true);
  if (::arrow::is_list_like(self_type->id())) {
    // is_list_like() also matches large_list / fixed_size_list, which do not
    // derive from ListType. BaseListType is their common base and exposes
    // value_field(), so cast to it instead of dereferencing a null ListType.
    auto list_type = std::dynamic_pointer_cast<::arrow::BaseListType>(arrow_field.type());
    if (!list_type) {
      return ::arrow::Status::Invalid("Can not merge a non-list type into list field: ",
                                      arrow_field.type()->ToString());
    }
    // A list-like field without children has nothing to merge recursively;
    // guard so that children_[0] is never accessed out of bounds.
    if (!new_field->children_.empty()) {
      ARROW_ASSIGN_OR_RAISE(new_field->children_[0],
                            new_field->children_[0]->Merge(*list_type->value_field()));
    }
  } else if (lance::arrow::is_struct(self_type)) {
    auto struct_type = std::dynamic_pointer_cast<::arrow::StructType>(arrow_field.type());
    if (!struct_type) {
      return ::arrow::Status::Invalid("Can not merge a non-struct type into struct field: ",
                                      arrow_field.type()->ToString());
    }
    for (const auto& arrow_child : struct_type->fields()) {
      bool found = false;
      for (std::size_t i = 0; i < new_field->children_.size(); ++i) {
        if (new_field->children_[i]->name_ == arrow_child->name()) {
          ARROW_ASSIGN_OR_RAISE(new_field->children_[i],
                                new_field->children_[i]->Merge(*arrow_child));
          found = true;
          break;
        }
      }
      if (!found) {
        // Child only exists on the incoming side: append it as a new field.
        new_field->children_.emplace_back(std::make_shared<Field>(arrow_child));
      }
    }
  }
  return new_field;
}

bool Field::Equals(const Field& other, bool check_id) const {
if (check_id && (id_ != other.id_ || parent_ != other.parent_)) {
return false;
Expand Down Expand Up @@ -570,6 +611,27 @@ ::arrow::Result<std::shared_ptr<Schema>> Schema::Exclude(const Schema& other) co
return excluded;
}

/// Merge an arrow schema into this schema and return the result as a new Schema.
///
/// Top-level fields of this schema come first, in their original order; those
/// with a same-named field in \p arrow_schema are merged via Field::Merge.
/// Top-level fields that only exist in \p arrow_schema are appended afterwards.
/// Finally AssignIds() is called on the merged schema.
///
/// \param arrow_schema the arrow schema to merge in.
/// \return the merged schema, or the error returned by Field::Merge.
::arrow::Result<std::shared_ptr<Schema>> Schema::Merge(const ::arrow::Schema& arrow_schema) const {
  auto merged = std::make_shared<Schema>();
  for (const auto& field : fields_) {
    auto arrow_field = arrow_schema.GetFieldByName(field->name());
    if (arrow_field) {
      ARROW_ASSIGN_OR_RAISE(auto new_field, field->Merge(*arrow_field));
      merged->AddField(new_field);
    } else {
      // Copy instead of sharing the pointer: AssignIds() below calls SetId() on
      // every field of `merged`, which would otherwise mutate Field objects
      // still owned by this (const) schema.
      merged->AddField(field->Copy(true));
    }
  }
  for (const auto& arrow_field : arrow_schema.fields()) {
    // Compare against top-level names only. GetField(name) resolves dotted
    // paths such as "a.b", so a new top-level column whose name contains a dot
    // could be mistaken for an existing nested field and silently dropped.
    bool exists = false;
    for (const auto& field : fields_) {
      if (field->name() == arrow_field->name()) {
        exists = true;
        break;
      }
    }
    if (!exists) {
      merged->AddField(std::make_shared<Field>(arrow_field));
    }
  }
  // Assign to new IDs
  merged->AssignIds();
  return merged;
}

/// Append \p f to the list of top-level fields of this schema.
void Schema::AddField(std::shared_ptr<Field> f) {
  fields_.push_back(f);
}

std::shared_ptr<Field> Schema::GetField(int32_t id) const {
Expand Down Expand Up @@ -629,12 +691,31 @@ std::shared_ptr<Schema> Schema::Copy() const {
}

/// Assign IDs to the fields of this schema.
///
/// The counter starts right after the largest ID currently in use
/// (GetMaxId() + 1) and is handed to Field::SetId() for every top-level field,
/// with -1 as the parent ID.
void Schema::AssignIds() {
  // int32_t matches the `int32_t*` counter parameter of Field::SetId().
  int32_t cur_id = GetMaxId() + 1;
  for (auto& field : fields_) {
    field->SetId(-1, &cur_id);
  }
}

/// Return the largest field ID found in this schema, or -1 if no visited field
/// has an ID greater than -1.
int32_t Schema::GetMaxId() const {
  // Local visitor that records the largest id() it sees, descending into the
  // children of every field it is handed.
  class MaxIdVisitor : public FieldVisitor {
   public:
    ::arrow::Status Visit(std::shared_ptr<Field> field) override {
      const int32_t field_id = field->id();
      if (max_id_ < field_id) {
        max_id_ = field_id;
      }
      const auto& children = field->children_;
      for (std::size_t idx = 0; idx < children.size(); ++idx) {
        ARROW_RETURN_NOT_OK(Visit(children[idx]));
      }
      return ::arrow::Status::OK();
    }

    int32_t max_id_ = -1;
  };

  MaxIdVisitor collector;
  const auto status = collector.VisitSchema(*this);
  if (!status.ok()) {
    // Best effort: report the failure and return whatever was collected so far.
    fmt::print(stderr, "Error when collecting max ID");
  }
  return collector.max_id_;
}

bool Schema::RemoveField(int32_t id) {
for (auto it = fields_.begin(); it != fields_.end(); ++it) {
if ((*it)->id() == id) {
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/lance/format/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ class Schema final {
/// \return The newly created schema, excluding any column in "other".
::arrow::Result<std::shared_ptr<Schema>> Exclude(const Schema& other) const;

/// Merge an arrow schema into this schema.
///
/// Top-level fields that exist in both schemas are merged by name; fields that
/// only exist in \p arrow_schema are appended after the existing fields. IDs
/// are then assigned on the merged schema via AssignIds(). The result is built
/// as a new Schema object.
///
/// \param arrow_schema the schema to be merged.
/// \return A newly merged schema, or an error status if merging a pair of
///         same-named fields fails.
::arrow::Result<std::shared_ptr<Schema>> Merge(const ::arrow::Schema& arrow_schema) const;

/// Add a new parent (top-level) field to the end of this schema's field list.
void AddField(std::shared_ptr<Field> f);

Expand Down Expand Up @@ -114,6 +120,9 @@ class Schema final {
/// Assign Field IDs to the fields of this schema.
///
/// Fields that already carry a non-negative ID keep it; fields without one
/// receive fresh IDs, counting up from the current max ID plus one.
void AssignIds();

/// Get the max assigned ID, looking at nested child fields as well.
///
/// \return the largest field ID, or -1 if no field has an ID assigned.
int32_t GetMaxId() const;

/// Make a full copy of the schema.
std::shared_ptr<Schema> Copy() const;

Expand Down Expand Up @@ -216,6 +225,9 @@ class Field final {
/// Project an arrow field to this field.
std::shared_ptr<Field> Project(const std::shared_ptr<::arrow::Field>& arrow_field) const;

/// Merge an arrow field with this field.
///
/// This field is not modified; the merge result is returned as a new Field.
/// List item fields and same-named struct children are merged recursively,
/// and struct children only present in \p arrow_field are appended.
///
/// \param arrow_field the arrow field to merge; it must have the same name and
///        the same arrow type id as this field.
/// \return the merged field, or an Invalid status if the name or type id differ.
::arrow::Result<std::shared_ptr<Field>> Merge(const ::arrow::Field& arrow_field) const;

/// Load dictionary array from disk.
::arrow::Status LoadDictionary(std::shared_ptr<::arrow::io::RandomAccessFile> infile);

Expand Down
42 changes: 42 additions & 0 deletions cpp/src/lance/format/schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "lance/testing/json.h"

using lance::arrow::ToArray;
using lance::format::Schema;
using lance::testing::MakeDataset;
using lance::testing::TableFromJSON;

Expand Down Expand Up @@ -217,4 +218,45 @@ TEST_CASE("Test schema metadata") {
CHECK(dataset->schema()->metadata());
CHECK(dataset->schema()->metadata()->Get("k1").ValueOrDie() == "v1");
CHECK(dataset->schema()->metadata()->Get("k1").ValueOrDie() == "v1");
}

// Merging a schema with disjoint top-level columns keeps the existing IDs and
// gives the new column the next free ID.
TEST_CASE("Test merge two schemas") {
  auto base_arrow_schema = ::arrow::schema(
      {::arrow::field("a", ::arrow::int32()), ::arrow::field("b", ::arrow::utf8())});
  Schema base_schema(base_arrow_schema);

  auto addon_schema = ::arrow::schema({::arrow::field("c", ::arrow::list(::arrow::utf8()))});
  auto merge_result = base_schema.Merge(*addon_schema);
  auto merged = merge_result.ValueOrDie();

  CHECK(merged->GetField("a")->id() == 0);
  CHECK(merged->GetField("b")->id() == 1);
  CHECK(merged->GetField("c")->id() == 2);
}

// Merging two struct columns with the same name adds the new child ("c") to
// the existing struct and numbers it after the existing fields.
TEST_CASE("Test merge two structs") {
  auto field_b = ::arrow::field("b", ::arrow::int32());
  auto field_c = ::arrow::field("c", ::arrow::int64());

  Schema base_schema(::arrow::schema({::arrow::field("a", ::arrow::struct_({field_b}))}));
  auto schema_c = ::arrow::schema({::arrow::field("a", ::arrow::struct_({field_b, field_c}))});

  auto merged = base_schema.Merge(*schema_c).ValueOrDie();

  auto a = merged->GetField("a");
  CHECK(a->id() == 0);

  auto expected_struct = ::arrow::struct_({field_b, field_c});
  INFO("Expected type: " << expected_struct->ToString() << " Got: " << a->type()->ToString());
  CHECK(a->type()->Equals(*expected_struct));

  CHECK(merged->GetField("a.b")->id() == 1);
  CHECK(merged->GetField("a.c")->id() == 2);
}

// Merging list<struct<b1>> with list<struct<b2>> yields list<struct<b1, b2>>.
TEST_CASE("Test merge two list of structs") {
  auto b1 = ::arrow::field("b1", ::arrow::int32());
  auto b2 = ::arrow::field("b2", ::arrow::utf8());

  Schema base_schema(
      ::arrow::schema({::arrow::field("a", ::arrow::list(::arrow::struct_({b1})))}));
  auto addon_schema =
      ::arrow::schema({::arrow::field("a", ::arrow::list(::arrow::struct_({b2})))});

  auto merged = base_schema.Merge(*addon_schema).ValueOrDie();

  auto expected_schema =
      ::arrow::schema({::arrow::field("a", ::arrow::list(::arrow::struct_({b1, b2})))});
  CHECK(merged->ToArrow()->Equals(expected_schema));
}

0 comments on commit e8922f9

Please sign in to comment.