diff --git a/cpp/src/lance/arrow/type.cc b/cpp/src/lance/arrow/type.cc index 2980278b84..68eef385c6 100644 --- a/cpp/src/lance/arrow/type.cc +++ b/cpp/src/lance/arrow/type.cc @@ -62,6 +62,11 @@ ::arrow::Result ToLogicalType(std::shared_ptr<::arrow::DataType> dt } else if (::arrow::is_fixed_size_binary(dtype->id())) { auto fixed_type = std::reinterpret_pointer_cast<::arrow::FixedSizeBinaryType>(dtype); return fmt::format("fixed_size_binary:{}", fixed_type->byte_width()); + } else if (is_fixed_size_list(dtype)) { + auto list_type = std::dynamic_pointer_cast<::arrow::FixedSizeListType>(dtype); + assert(::arrow::is_primitive(list_type->value_type()->id())); + ARROW_ASSIGN_OR_RAISE(auto value_type, ToLogicalType(list_type->value_type())); + return fmt::format("fixed_size_list:{}:{}", value_type, list_type->list_size()); } else if (dtype->id() == ::arrow::Date32Type::type_id) { return "date32:day"; } else if (dtype->id() == ::arrow::Date64Type::type_id) { @@ -167,6 +172,21 @@ ::arrow::Result> FromLogicalType( return ::arrow::fixed_size_binary(size); } + if (logical_type.starts_with("fixed_size_list:")) { + auto components = ::arrow::internal::SplitString(logical_type, ':'); + if (components.size() != 3) { + return ::arrow::Status::Invalid( + fmt::format("Invalid fixed size list string: {}", logical_type.to_string())); + } + ARROW_ASSIGN_OR_RAISE(auto value_type, FromLogicalType(components[1])); + auto size = std::stoi(components[2].to_string()); + if (size == 0) { + return ::arrow::Status::Invalid( + fmt::format("Invalid fixe size binary string: {}", logical_type.to_string())); + } + return ::arrow::fixed_size_list(value_type, size); + } + if (logical_type.starts_with("dict")) { auto components = ::arrow::internal::SplitString(logical_type, ':'); if (components.size() != 4) { diff --git a/cpp/src/lance/arrow/type_test.cc b/cpp/src/lance/arrow/type_test.cc index 8312f46131..478f648820 100644 --- a/cpp/src/lance/arrow/type_test.cc +++ b/cpp/src/lance/arrow/type_test.cc @@ -75,6 +75,7 @@ TEST_CASE("Logical type coverage") { {::arrow::time32(::arrow::TimeUnit::MILLI), "time32:ms"}, {::arrow::time64(::arrow::TimeUnit::MICRO), "time64:us"}, {::arrow::time64(::arrow::TimeUnit::NANO), "time64:ns"}, + {::arrow::fixed_size_list(::arrow::int32(), 4), "fixed_size_list:int32:4"}, }); for (auto& [arrow_type, type_str] : kArrayTypeMap) { diff --git a/cpp/src/lance/format/schema.cc b/cpp/src/lance/format/schema.cc index 4819315098..be536b6861 100644 --- a/cpp/src/lance/format/schema.cc +++ b/cpp/src/lance/format/schema.cc @@ -69,7 +69,8 @@ void Field::Init(std::shared_ptr<::arrow::DataType> dtype) { auto type_id = dtype->id(); if (::arrow::is_binary_like(type_id) || ::arrow::is_large_binary_like(type_id)) { encoding_ = pb::VAR_BINARY; - } else if (::arrow::is_primitive(type_id)) { + } else if (::arrow::is_primitive(type_id) || ::arrow::is_fixed_size_binary(type_id) || + lance::arrow::is_fixed_size_list(dtype)) { encoding_ = pb::PLAIN; } else if (::arrow::is_dictionary(type_id)) { encoding_ = pb::DICTIONARY; diff --git a/cpp/src/lance/format/schema_test.cc b/cpp/src/lance/format/schema_test.cc index 7a825bb48e..70b596883f 100644 --- a/cpp/src/lance/format/schema_test.cc +++ b/cpp/src/lance/format/schema_test.cc @@ -1,11 +1,12 @@ #include "lance/format/schema.h" -#include "lance/arrow/testing.h" #include #include #include +#include "lance/arrow/testing.h" + const auto arrow_schema = ::arrow::schema( {::arrow::field("pk", ::arrow::utf8()), ::arrow::field("split", ::arrow::utf8()), @@ -20,12 +21,9 @@ const auto arrow_schema = ::arrow::schema( ::arrow::field("ymax", ::arrow::float32()), }))})))}); - std::shared_ptr<::arrow::DataType> image_type = std::make_shared<::lance::testing::ImageType>(); -const auto ext_schema = ::arrow::schema( - {::arrow::field("pk", ::arrow::utf8()), - ::arrow::field("image", image_type)}); - +const auto ext_schema = + ::arrow::schema({::arrow::field("pk", ::arrow::utf8()), ::arrow::field("image", image_type)}); TEST_CASE("Get field by name") { auto schema = lance::format::Schema(arrow_schema); @@ -129,12 +127,25 @@ TEST_CASE("Project nested extension type") { auto original = lance::format::Schema(ext_schema); auto projection = original.Project({"image.uri"}).ValueOrDie(); - auto expected_schema = ::arrow::schema({::arrow::field( - "image", - ::arrow::struct_({::arrow::field("uri", ::arrow::utf8())}))}); + auto expected_schema = ::arrow::schema( + {::arrow::field("image", ::arrow::struct_({::arrow::field("uri", ::arrow::utf8())}))}); auto expected = lance::format::Schema(expected_schema); INFO("Expected: " << expected.ToString()); INFO("Actual: " << projection->ToString()); CHECK(projection->Equals(expected, false)); } +TEST_CASE("Fixed size list") { + auto arrow_field = + ::arrow::field("fixed_size_list", ::arrow::fixed_size_list(::arrow::int32(), 4)); + auto field = ::lance::format::Field(arrow_field); + CHECK(field.encoding() == ::lance::format::pb::PLAIN); + CHECK(field.logical_type() == "fixed_size_list:int32:4"); +} + +TEST_CASE("Fixed size binary") { + auto arrow_field = ::arrow::field("fs_binary", ::arrow::fixed_size_binary(100)); + auto field = ::lance::format::Field(arrow_field); + CHECK(field.encoding() == ::lance::format::pb::PLAIN); + CHECK(field.logical_type() == "fixed_size_binary:100"); +} \ No newline at end of file