Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic extension type handling #90

Merged
merged 8 commits into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/lance/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ add_library(
utils.cc
utils.h
writer.cc
)
testing.h)
target_include_directories(arrow SYSTEM PRIVATE ${Protobuf_INCLUDE_DIR})

add_lance_test(api_test)
Expand Down
62 changes: 62 additions & 0 deletions cpp/src/lance/arrow/scanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <memory>

#include "lance/arrow/type.h"
#include "lance/arrow/testing.h"
#include "lance/format/schema.h"


auto nested_schema = ::arrow::schema({::arrow::field("pk", ::arrow::int32()),
::arrow::field("objects",
Expand Down Expand Up @@ -78,4 +81,63 @@ TEST_CASE("Build Scanner with nested struct") {
CHECK(scanner->options()->batch_readahead == 1);

fmt::print("Scanner Options: {}\n", scanner->options()->filter.ToString());
}


std::shared_ptr<::arrow::Table> MakeTable() {
auto ext_type = std::make_shared<::lance::testing::ParametricType>(1);
::arrow::StringBuilder stringBuilder;
::arrow::Int32Builder intBuilder;

CHECK(stringBuilder.AppendValues({"train", "train", "split", "train"}).ok());
auto c1 = stringBuilder.Finish().ValueOrDie();
stringBuilder.Reset();

CHECK(intBuilder.AppendValues({1, 2, 3, 4}).ok());
auto c2 = intBuilder.Finish().ValueOrDie();
intBuilder.Reset();

auto schema = ::arrow::schema({arrow::field("c1", ::arrow::utf8()),
arrow::field("c2", ext_type)});

std::vector<std::shared_ptr<::arrow::Array>> cols;
cols.push_back(c1);
cols.push_back(::arrow::ExtensionType::WrapArray(ext_type, c2));
return ::arrow::Table::Make(std::move(schema), std::move(cols));
}

/// Wraps `table` in an in-memory dataset and builds a Lance scanner over it
/// that is limited to 2 rows and projects only column "c2".
std::shared_ptr<::arrow::dataset::Scanner> MakeScanner(std::shared_ptr<::arrow::Table> table) {
  auto in_memory = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
  lance::arrow::ScannerBuilder builder(in_memory);
  builder.Limit(2);
  builder.Project({"c2"});
  // TODO how can extension types implement comparisons for filtering against storage type?
  auto finished = builder.Finish();
  CHECK(finished.ok());
  auto built_scanner = finished.ValueOrDie();
  fmt::print("Projected: {}\n", built_scanner->options()->projected_schema);
  return built_scanner;
}


// Verifies that projecting an extension-typed column through the Lance scanner
// keeps the extension type in both the projected schema and the scanned table.
TEST_CASE("Scanner with extension") {
  auto table = MakeTable();
  auto ext_type = std::make_shared<::lance::testing::ParametricType>(1);
  // The Arrow extension registry is process-global. Remember whether this test
  // added the entry so it can be removed again at the end instead of leaking
  // into whichever test runs next.
  const bool registered_here = ::arrow::RegisterExtensionType(ext_type).ok();
  auto scanner = MakeScanner(table);

  auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
  INFO("Dataset schema is " << dataset->schema()->ToString());

  auto schema = ::lance::format::Schema(dataset->schema());
  INFO("Lance schema is " << schema.ToString());

  auto expected_proj_schema = ::arrow::schema({::arrow::field("c2", ext_type)});
  INFO("Expected schema: " << expected_proj_schema->ToString());
  INFO("Actual schema: " << scanner->options()->projected_schema->ToString());
  CHECK(expected_proj_schema->Equals(scanner->options()->projected_schema));

  auto actual_table = scanner->ToTable().ValueOrDie();
  CHECK(actual_table->schema()->Equals(expected_proj_schema));
  CHECK(actual_table->GetColumnByName("c2")->type()->Equals(ext_type));

  // Only undo a registration this test performed itself.
  if (registered_here) {
    CHECK(::arrow::UnregisterExtensionType(ext_type->extension_name()).ok());
  }
}
93 changes: 93 additions & 0 deletions cpp/src/lance/arrow/testing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2022 Lance Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arrow/extension_type.h>
#include <arrow/type.h>
#include <fmt/format.h>

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>


namespace lance {
namespace testing {

/// Test-only extension type named "image".
///
/// Its storage type is a struct with a utf8 "uri" field and an int32 "data"
/// field. Two extension types compare equal when their extension names match.
class ImageType : public ::arrow::ExtensionType {
 public:
  ImageType()
      : ::arrow::ExtensionType(::arrow::struct_({
            ::arrow::field("uri", ::arrow::utf8()),
            ::arrow::field("data", ::arrow::int32()),
        })) {}

  std::string extension_name() const override { return "image"; }

  bool ExtensionEquals(const ::arrow::ExtensionType& other) const override {
    return other.extension_name() == extension_name();
  }

  /// Wraps `data` in a generic ExtensionArray.
  std::shared_ptr<::arrow::Array> MakeArray(
      std::shared_ptr<::arrow::ArrayData> data) const override {
    return std::make_shared<::arrow::ExtensionArray>(data);
  }

  /// Rebuilds an ImageType from its serialized form.
  ///
  /// Returns Status::Invalid unless `serialized` is exactly what Serialize()
  /// produces. The identifier is taken from Serialize() so the two methods
  /// cannot drift apart (previously they used different strings, so the type
  /// could never deserialize its own output). `storage_type` is not inspected.
  ::arrow::Result<std::shared_ptr<::arrow::DataType>> Deserialize(
      std::shared_ptr<::arrow::DataType> storage_type,
      const std::string& serialized) const override {
    if (serialized != Serialize()) {
      return ::arrow::Status::Invalid("Type identifier did not match");
    }
    return std::make_shared<ImageType>();
  }

  /// Fixed identifier; the type has no parameters to encode.
  std::string Serialize() const override { return "image-ext"; }
};

// A parametric type where the extension_name() is always the same.
//
// The storage type is int32 and each instance carries one int32 parameter.
// Two instances are equal only when their parameters match. The serialized
// form is the raw bytes of the parameter as laid out in memory on the host.
class ParametricType : public ::arrow::ExtensionType {
 public:
  explicit ParametricType(int32_t parameter)
      : ::arrow::ExtensionType(::arrow::int32()), parameter_(parameter) {}

  int32_t parameter() const { return parameter_; }

  std::string extension_name() const override { return "parametric-type"; }

  bool ExtensionEquals(const ::arrow::ExtensionType& other) const override {
    if (other.extension_name() != this->extension_name()) {
      return false;
    }
    // The name check above is what justifies the downcast.
    return this->parameter() == static_cast<const ParametricType&>(other).parameter();
  }

  // Wraps `data` in a generic ExtensionArray.
  std::shared_ptr<::arrow::Array> MakeArray(
      std::shared_ptr<::arrow::ArrayData> data) const override {
    return std::make_shared<::arrow::ExtensionArray>(data);
  }

  // Rebuilds a ParametricType from the bytes written by Serialize().
  //
  // Returns Status::Invalid when `serialized` is not exactly
  // sizeof(int32_t) bytes long. `storage_type` is not inspected.
  ::arrow::Result<std::shared_ptr<::arrow::DataType>> Deserialize(
      std::shared_ptr<::arrow::DataType> storage_type,
      const std::string& serialized) const override {
    if (serialized.size() != sizeof(int32_t)) {
      return ::arrow::Status::Invalid("ParametricType: serialized parameter has wrong size");
    }
    // Copy instead of reinterpret_cast: the string buffer carries no alignment
    // guarantee for int32_t.
    int32_t parameter = 0;
    std::memcpy(&parameter, serialized.data(), sizeof(parameter));
    return std::make_shared<ParametricType>(parameter);
  }

  // Encodes the parameter as sizeof(int32_t) raw bytes.
  std::string Serialize() const override {
    // Size the buffer from the type so the copy can never overrun it.
    std::string result(sizeof(parameter_), '\0');
    std::memcpy(&result[0], &parameter_, sizeof(parameter_));
    return result;
  }

 private:
  int32_t parameter_;
};

}
}
14 changes: 12 additions & 2 deletions cpp/src/lance/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ std::string ToString(::arrow::TimeUnit::type unit) {
} // namespace

::arrow::Result<std::string> ToLogicalType(std::shared_ptr<::arrow::DataType> dtype) {
if (dtype->id() == ::arrow::Type::EXTENSION) {
changhiskhan marked this conversation as resolved.
Show resolved Hide resolved
auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(dtype);
return ToLogicalType(ext_type->storage_type());
}

if (is_list(dtype)) {
auto list_type = std::reinterpret_pointer_cast<::arrow::ListType>(dtype);
return is_struct(list_type->value_type()) ? "list.struct" : "list";
Expand Down Expand Up @@ -176,8 +181,13 @@ ::arrow::Result<std::shared_ptr<::arrow::DataType>> FromLogicalType(
"FromLogicalType: logical_type \"{}\" is not supported yet", logical_type.to_string()));
}

bool is_timestamp(std::shared_ptr<::arrow::DataType> dtype) {
return dtype->id() == ::arrow::TimestampType::type_id;

/// Returns the extension name when `dtype` is an Arrow extension type,
/// otherwise std::nullopt.
std::optional<std::string> GetExtensionName(std::shared_ptr<::arrow::DataType> dtype) {
  if (dtype->id() != ::arrow::Type::EXTENSION) {
    return std::nullopt;
  }
  // The type id check above makes this downcast valid.
  const auto& extension = static_cast<const ::arrow::ExtensionType&>(*dtype);
  return extension.extension_name();
}

} // namespace lance::arrow
11 changes: 10 additions & 1 deletion cpp/src/lance/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <concepts>
#include <memory>
#include <optional>
#include <string>
#include <vector>

Expand Down Expand Up @@ -65,12 +66,20 @@ inline bool is_map(std::shared_ptr<::arrow::DataType> dtype) {
}

/// Returns True if the data type is timestamp type.
bool is_timestamp(std::shared_ptr<::arrow::DataType> dtype);
inline bool is_timestamp(std::shared_ptr<::arrow::DataType> dtype) {
  // TimestampType::type_id is the Type::TIMESTAMP enumerator.
  return ::arrow::Type::TIMESTAMP == dtype->id();
}

/// Returns True if the data type is an extension type.
inline bool is_extension(std::shared_ptr<::arrow::DataType> dtype) {
return dtype->id() == ::arrow::Type::EXTENSION;
}

/// Convert arrow DataType to a string representation.
::arrow::Result<std::string> ToLogicalType(std::shared_ptr<::arrow::DataType> dtype);

::arrow::Result<std::shared_ptr<::arrow::DataType>> FromLogicalType(
::arrow::util::string_view logical_type);

std::optional<std::string> GetExtensionName(std::shared_ptr<::arrow::DataType> dtype);

} // namespace lance::arrow
72 changes: 71 additions & 1 deletion cpp/src/lance/arrow/writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@

#include "lance/arrow/reader.h"
#include "lance/arrow/type.h"
#include "lance/format/schema.h"
#include "lance/arrow/testing.h"
#include "lance/io/reader.h"

using arrow::ArrayBuilder;
using arrow::Int32Builder;
using arrow::ListBuilder;
using arrow::StringBuilder;
using arrow::StructBuilder;
using arrow::LargeBinaryBuilder;
using arrow::Table;
using lance::arrow::FileReader;
using lance::format::Schema;

using std::make_shared;
using std::map;
Expand Down Expand Up @@ -172,6 +176,7 @@ TEST_CASE("Write dictionary type") {
CHECK(table->Equals(*actual_table));
}


TEST_CASE("Large binary field") {
auto field_type = ::arrow::large_binary();
auto schema = ::arrow::schema({arrow::field("f1", field_type)});
Expand Down Expand Up @@ -202,4 +207,69 @@ TEST_CASE("Binary field") {

auto sink = arrow::io::BufferOutputStream::Create().ValueOrDie();
CHECK(lance::arrow::WriteTable(*table, sink).ok());
}
}


std::shared_ptr<::arrow::Table> MakeTable() {
auto ext_type = std::make_shared<::lance::testing::ImageType>();
auto uriBuilder = std::make_shared<::arrow::StringBuilder>();
auto dataBuilder = std::make_shared<::arrow::Int32Builder>();
auto imageBuilder = std::make_shared<::arrow::StructBuilder>(
ext_type->storage_type(),
arrow::default_memory_pool(),
std::vector<std::shared_ptr<::arrow::ArrayBuilder>>({uriBuilder, dataBuilder}));
for (int i = 0; i < 4; i++) {
CHECK(imageBuilder->Append().ok());
CHECK(uriBuilder->Append(fmt::format("s3://{}", i)).ok());
CHECK(dataBuilder->Append(i).ok());
}
auto arr = imageBuilder->Finish().ValueOrDie();
INFO("array is " << arr->ToString());

auto schema = ::arrow::schema({arrow::field("image_ext", ext_type)});
std::vector<std::shared_ptr<::arrow::Array>> cols;
cols.push_back(::arrow::ExtensionType::WrapArray(ext_type, arr));
return ::arrow::Table::Make(std::move(schema), std::move(cols));
}

/// Finishes `sink`, opens its buffer with the Lance file reader and returns the
/// whole content as a table. Expects exactly one batch holding four rows.
std::shared_ptr<::arrow::Table> ReadTable(std::shared_ptr<arrow::io::BufferOutputStream> sink) {
  auto infile = std::make_shared<arrow::io::BufferReader>(sink->Finish().ValueOrDie());
  // Open the reader once, so the status that gets logged belongs to the reader
  // that is actually used below.
  auto reader_result = ::lance::arrow::FileReader::Make(infile);
  // Copy the status before the result is consumed.
  const auto make_status = reader_result.status();
  INFO(make_status);
  auto reader = std::move(reader_result).ValueOrDie();
  CHECK(reader->num_batches() == 1);
  CHECK(reader->length() == 4);
  return reader->ReadTable().ValueOrDie();
}

// Writes a table with an ImageType column and reads it back; the column is
// expected to come back as its plain struct storage with no extension info.
TEST_CASE("Write extension but read storage if not registered") {
  auto table = MakeTable();
  auto written_chunk = table->GetColumnByName("image_ext")->chunk(0);
  auto arr = std::static_pointer_cast<::arrow::ExtensionArray>(written_chunk)->storage();

  auto sink = arrow::io::BufferOutputStream::Create().ValueOrDie();
  CHECK(lance::arrow::WriteTable(*table, sink).ok());

  // We can read it back without the extension
  auto actual_table = ReadTable(sink);
  CHECK(arr->Equals(actual_table->GetColumnByName("image_ext")->chunk(0)));

  // The Lance schema sees a plain struct field with no extension name.
  Schema lance_schema(actual_table->schema());
  auto image_field = lance_schema.GetField("image_ext");
  CHECK(image_field->logical_type() == "struct");
  CHECK(image_field->extension_name() == "");
  CHECK_FALSE(image_field->is_extension_type());
  CHECK(lance_schema.GetFieldsCount() == 3);
}


// Writes a table with an ImageType column while the type is registered and
// checks that the table read back equals the one written.
TEST_CASE("Extension type round-trip") {
  auto ext_type = std::make_shared<::lance::testing::ImageType>();
  // The Arrow extension registry is process-global. Remember whether this test
  // added the entry so it can be removed again at the end; otherwise tests
  // that expect "image" to be unregistered depend on execution order.
  const bool registered_here = arrow::RegisterExtensionType(ext_type).ok();
  auto table = MakeTable();
  auto sink = arrow::io::BufferOutputStream::Create().ValueOrDie();
  CHECK(lance::arrow::WriteTable(*table, sink).ok());

  // With the extension registered, the table read back matches the original
  auto actual_table = ReadTable(sink);
  CHECK(table->Equals(*actual_table));

  // Only undo a registration this test performed itself.
  if (registered_here) {
    CHECK(arrow::UnregisterExtensionType(ext_type->extension_name()).ok());
  }
}

Loading