Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Fix reading dictionary values from manifest files #314

Merged
merged 10 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions cpp/src/lance/arrow/dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

#include "lance/arrow/dataset.h"

#include <arrow/array.h>
#include <arrow/dataset/api.h>
#include <arrow/status.h>
#include <arrow/table.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <uuid.h>
Expand All @@ -31,6 +33,7 @@
#include "lance/arrow/fragment.h"
#include "lance/format/manifest.h"
#include "lance/format/schema.h"
#include "lance/io/reader.h"
#include "lance/io/writer.h"

namespace fs = std::filesystem;
Expand Down Expand Up @@ -82,7 +85,56 @@ std::string GetBasenameTemplate() {
::arrow::Result<std::shared_ptr<lance::format::Manifest>> OpenManifest(
const std::shared_ptr<::arrow::fs::FileSystem>& fs, const std::string& path) {
ARROW_ASSIGN_OR_RAISE(auto in, fs->OpenInputFile(path));
return lance::format::Manifest::Parse(in, 0);
return lance::io::FileReader::OpenManifest(in);
}

/// Recursively collect dictionary values from an array into the matching field.
///
/// If \p field has a dictionary type, the dictionary of \p arr is stored on the
/// field via `Field::set_dictionary`. For list fields the walk continues into
/// the list values; for struct fields it continues into every child, looked up
/// by name in the struct array.
///
/// \param field the schema field to update. Must not be null.
/// \param arr an array with the same data type as \p field. Must not be null.
/// \return `Status::OK()` on success, or `Status::Invalid` if the array does not
///         match the shape described by the field.
::arrow::Status CollectDictionary(const std::shared_ptr<lance::format::Field>& field,
                                  const std::shared_ptr<::arrow::Array>& arr) {
  assert(field && arr);
  assert(field->type()->Equals(arr->type()));
  auto data_type = field->type();
  if (::arrow::is_dictionary(data_type->id())) {
    auto dict_arr = std::dynamic_pointer_cast<::arrow::DictionaryArray>(arr);
    // The asserts above disappear in release builds, so guard the cast result
    // instead of dereferencing a possibly-null pointer.
    if (dict_arr == nullptr) {
      return ::arrow::Status::Invalid("CollectDictionary: schema mismatch: field ",
                                      field->name(),
                                      " is a dictionary but the array type is: ",
                                      arr->type()->ToString());
    }
    return field->set_dictionary(dict_arr->dictionary());
  }

  if (is_list(data_type)) {
    auto list_arr = std::dynamic_pointer_cast<::arrow::ListArray>(arr);
    if (list_arr == nullptr) {
      return ::arrow::Status::Invalid("CollectDictionary: schema mismatch: field ",
                                      field->name(),
                                      " is a list but the array type is: ",
                                      arr->type()->ToString());
    }
    ARROW_RETURN_NOT_OK(CollectDictionary(field->field(0), list_arr->values()));
  } else if (is_struct(data_type)) {
    auto struct_arr = std::dynamic_pointer_cast<::arrow::StructArray>(arr);
    if (struct_arr == nullptr) {
      return ::arrow::Status::Invalid("CollectDictionary: schema mismatch: field ",
                                      field->name(),
                                      " is a struct but the array type is: ",
                                      arr->type()->ToString());
    }
    for (const auto& child : field->fields()) {
      auto child_arr = struct_arr->GetFieldByName(child->name());
      if (child_arr == nullptr) {
        // Print the type via ToString(): streaming the shared_ptr itself would
        // only emit a pointer address.
        return ::arrow::Status::Invalid("CollectDictionary: schema mismatch: field ",
                                        child->name(),
                                        " does not exist in the table: ",
                                        struct_arr->type()->ToString());
      }
      ARROW_RETURN_NOT_OK(CollectDictionary(child, child_arr));
    }
  }
  return ::arrow::Status::OK();
}

/// Collect dictionary values for every top-level field of a schema.
///
/// Reads the first row from \p scanner and, for each field in \p schema, looks
/// up the column with the same name and hands its first chunk to the per-field
/// `CollectDictionary` overload.
///
/// \param schema the schema whose fields receive the dictionaries.
/// \param scanner the scanner to sample one row from.
/// \return `Status::OK()` on success; `Status::Invalid` if the scanner yields
///         no rows or a schema field is missing from the sampled table; or any
///         error raised by the scanner or the per-field collection.
::arrow::Status CollectDictionary(const std::shared_ptr<lance::format::Schema>& schema,
                                  const std::shared_ptr<::arrow::dataset::Scanner>& scanner) {
  ARROW_ASSIGN_OR_RAISE(auto example, scanner->Head(1));
  if (example->num_rows() == 0) {
    return ::arrow::Status::Invalid("CollectDictionary: empty dataset");
  }
  for (const auto& field : schema->fields()) {
    auto chunked_arr = example->GetColumnByName(field->name());
    if (chunked_arr == nullptr) {
      // Print the schema via ToString(): streaming the shared_ptr itself would
      // only emit a pointer address.
      return ::arrow::Status::Invalid("CollectDictionary: schema mismatch: field ",
                                      field->name(),
                                      " does not exist in the table: ",
                                      example->schema()->ToString());
    }
    // Checked at runtime (not only via assert) so release builds never call
    // chunk(0) on a column without chunks.
    if (chunked_arr->num_chunks() <= 0) {
      return ::arrow::Status::Invalid("CollectDictionary: column ",
                                      field->name(),
                                      " has no chunks");
    }
    ARROW_RETURN_NOT_OK(CollectDictionary(field, chunked_arr->chunk(0)));
  }
  return ::arrow::Status::OK();
}

} // namespace
Expand Down Expand Up @@ -164,6 +216,7 @@ ::arrow::Status LanceDataset::Write(const ::arrow::dataset::FileSystemDatasetWri
auto schema = std::make_shared<lance::format::Schema>(scanner->options()->dataset_schema);
manifest = std::make_shared<lance::format::Manifest>(schema);
}
ARROW_RETURN_NOT_OK(CollectDictionary(manifest->schema(), scanner));

// Write manifest file
auto lance_option = options;
Expand Down Expand Up @@ -205,7 +258,7 @@ ::arrow::Status LanceDataset::Write(const ::arrow::dataset::FileSystemDatasetWri
auto manifest_path = GetManifestPath(base_dir, manifest->version());
{
ARROW_ASSIGN_OR_RAISE(auto out, fs->OpenOutputStream(manifest_path));
ARROW_RETURN_NOT_OK(manifest->Write(out));
ARROW_RETURN_NOT_OK(lance::io::FileWriter::WriteManifest(out, *manifest));
}
auto latest_manifest_path = GetManifestPath(base_dir, std::nullopt);
return fs->CopyFile(manifest_path, latest_manifest_path);
Expand Down Expand Up @@ -263,7 +316,7 @@ ::arrow::Result<::arrow::dataset::FragmentIterator> LanceDataset::GetFragmentsIm
std::vector<std::shared_ptr<::arrow::dataset::Fragment>> fragments =
impl_->manifest->fragments() | views::transform([this](auto& data_fragment) {
return std::make_shared<LanceFragment>(
impl_->fs, impl_->data_dir(), data_fragment, impl_->manifest->schema());
impl_->fs, impl_->data_dir(), data_fragment, impl_->manifest);
}) |
ranges::to<decltype(fragments)>;

Expand Down
39 changes: 37 additions & 2 deletions cpp/src/lance/arrow/dataset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

#include <catch2/catch_test_macros.hpp>
#include <memory>
#include <string>
#include <range/v3/view.hpp>
#include <string>

#include "lance/arrow/file_lance.h"
#include "lance/arrow/stl.h"
Expand All @@ -29,7 +29,6 @@
using lance::arrow::ToArray;
using namespace ranges::views;


std::shared_ptr<::arrow::Table> ReadTable(const std::string& uri, std::optional<int32_t> version) {
std::string path;
auto fs = ::arrow::fs::FileSystemFromUriOrPath(uri, &path).ValueOrDie();
Expand All @@ -38,6 +37,24 @@ std::shared_ptr<::arrow::Table> ReadTable(const std::string& uri, std::optional<
return actual_dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie();
}

// Write `table` as a Lance dataset under a fresh temporary directory.
//
// Returns the base URI of the written dataset. The write must succeed: a
// failure stops the calling test case immediately (REQUIRE) and reports the
// status message, rather than letting the test continue and read from a
// dataset that was never written.
std::string WriteTable(const std::shared_ptr<::arrow::Table>& table) {
  auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata";
  auto format = lance::arrow::LanceFileFormat::Make();
  ::arrow::dataset::FileSystemDatasetWriteOptions write_options;
  std::string path;
  auto fs = ::arrow::fs::FileSystemFromUriOrPath(base_uri, &path).ValueOrDie();
  write_options.filesystem = fs;
  write_options.base_dir = path;
  write_options.file_write_options = format->DefaultWriteOptions();

  auto dataset = lance::testing::MakeDataset(table).ValueOrDie();
  auto scanner = dataset->NewScan().ValueOrDie()->Finish().ValueOrDie();
  auto status = lance::arrow::LanceDataset::Write(write_options, scanner);
  INFO("Status: " << status.message() << " is ok: " << status.ok());
  REQUIRE(status.ok());
  return base_uri;
}

TEST_CASE("Create new dataset") {
auto ids = ToArray({1, 2, 3, 4, 5, 6, 8}).ValueOrDie();
auto values = ToArray({"a", "b", "c", "d", "e", "f", "g"}).ValueOrDie();
Expand Down Expand Up @@ -199,4 +216,22 @@ TEST_CASE("Dataset overwrite error cases") {
lance::arrow::LanceDataset::kOverwrite);
INFO("Status: " << status.message() << " is ok: " << status.ok());
CHECK(status.IsIOError());
}

// Round-trip a table with a single dictionary-encoded column through a Lance
// dataset and check that version 1 of the dataset reads back equal to the
// input table.
TEST_CASE("Dataset write dictionary array") {
  auto dict_values = ToArray({"a", "b", "c"}).ValueOrDie();
  auto dict_indices = ToArray({0, 1, 1, 2, 2, 0}).ValueOrDie();
  auto data_type = ::arrow::dictionary(::arrow::int32(), ::arrow::utf8());
  auto dict_arr =
      ::arrow::DictionaryArray::FromArrays(data_type, dict_indices, dict_values).ValueOrDie();
  auto table =
      ::arrow::Table::Make(::arrow::schema({::arrow::field("dict", data_type)}), {dict_arr});

  auto base_uri = WriteTable(table);
  // Reported only when an assertion below fails, instead of printing to stdout
  // on every run.
  INFO("Base URI: " << base_uri);

  auto actual = ReadTable(base_uri, 1);
  CHECK(actual->Equals(*table));
}
2 changes: 1 addition & 1 deletion cpp/src/lance/arrow/file_lance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ ::arrow::Future<std::optional<int64_t>> LanceFileFormat::CountRows(
::arrow::Result<::arrow::RecordBatchGenerator> LanceFileFormat::ScanBatchesAsync(
const std::shared_ptr<::arrow::dataset::ScanOptions>& options,
const std::shared_ptr<::arrow::dataset::FileFragment>& file) const {
ARROW_ASSIGN_OR_RAISE(auto fragment, LanceFragment::Make(*file, impl_->manifest->schema()));
ARROW_ASSIGN_OR_RAISE(auto fragment, LanceFragment::Make(*file, impl_->manifest));
ARROW_ASSIGN_OR_RAISE(auto batch_reader,
lance::io::RecordBatchReader::Make(*fragment, options));
return ::arrow::RecordBatchGenerator(std::move(batch_reader));
Expand Down
21 changes: 13 additions & 8 deletions cpp/src/lance/arrow/fragment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "lance/arrow/file_lance.h"
#include "lance/arrow/utils.h"
#include "lance/format/data_fragment.h"
#include "lance/format/manifest.h"
#include "lance/format/schema.h"
#include "lance/io/reader.h"
#include "lance/io/record_batch_reader.h"
Expand All @@ -33,22 +34,23 @@ namespace fs = std::filesystem;
namespace lance::arrow {

::arrow::Result<std::shared_ptr<LanceFragment>> LanceFragment::Make(
const ::arrow::dataset::FileFragment& file_fragment, std::shared_ptr<format::Schema> schema) {
auto field_ids = schema->GetFieldIds();
const ::arrow::dataset::FileFragment& file_fragment,
std::shared_ptr<format::Manifest> manifest) {
auto field_ids = manifest->schema()->GetFieldIds();
auto data_fragment = std::make_shared<format::DataFragment>(
format::DataFile(file_fragment.source().path(), field_ids));
return std::make_shared<LanceFragment>(
file_fragment.source().filesystem(), "", data_fragment, schema);
file_fragment.source().filesystem(), "", std::move(data_fragment), std::move(manifest));
}

LanceFragment::LanceFragment(std::shared_ptr<::arrow::fs::FileSystem> fs,
std::string data_dir,
std::shared_ptr<lance::format::DataFragment> fragment,
std::shared_ptr<format::Schema> schema)
std::shared_ptr<lance::format::Manifest> manifest)
: fs_(std::move(fs)),
data_uri_(std::move(data_dir)),
fragment_(std::move(fragment)),
schema_(std::move(schema)) {}
manifest_(std::move(manifest)) {}

::arrow::Result<::arrow::RecordBatchGenerator> LanceFragment::ScanBatchesAsync(
const std::shared_ptr<::arrow::dataset::ScanOptions>& options) {
Expand All @@ -57,7 +59,7 @@ ::arrow::Result<::arrow::RecordBatchGenerator> LanceFragment::ScanBatchesAsync(
}

::arrow::Result<std::shared_ptr<::arrow::Schema>> LanceFragment::ReadPhysicalSchemaImpl() {
return schema_->ToArrow();
return schema()->ToArrow();
}

::arrow::Result<std::vector<LanceFragment::FileReaderWithSchema>> LanceFragment::Open(
Expand All @@ -71,14 +73,15 @@ ::arrow::Result<std::vector<LanceFragment::FileReaderWithSchema>> LanceFragment:
executor->Submit(
[this, &schema](auto idx) -> ::arrow::Result<FileReaderWithSchema> {
auto& data_file = this->fragment_->data_files()[idx];
ARROW_ASSIGN_OR_RAISE(auto data_file_schema, schema_->Project(data_file.fields()));
ARROW_ASSIGN_OR_RAISE(auto data_file_schema,
this->schema()->Project(data_file.fields()));
ARROW_ASSIGN_OR_RAISE(auto intersection, schema.Intersection(*data_file_schema));
if (intersection->fields().empty()) {
return std::make_tuple(nullptr, nullptr);
}
auto full_path = (fs::path(data_uri_) / data_file.path()).string();
ARROW_ASSIGN_OR_RAISE(auto infile, fs_->OpenInputFile(full_path))
ARROW_ASSIGN_OR_RAISE(auto reader, lance::io::FileReader::Make(infile));
ARROW_ASSIGN_OR_RAISE(auto reader, lance::io::FileReader::Make(infile, this->manifest_));
return std::make_tuple(std::move(reader), intersection);
},
i));
Expand All @@ -96,6 +99,8 @@ ::arrow::Result<std::vector<LanceFragment::FileReaderWithSchema>> LanceFragment:
return readers;
}

/// Dataset schema, forwarded from the manifest held by this fragment.
// NOTE(review): this returns a reference to whatever `Manifest::schema()`
// yields; it presumably returns a reference that lives as long as the
// manifest -- confirm, otherwise this would dangle.
const std::shared_ptr<format::Schema>& LanceFragment::schema() const { return manifest_->schema(); }

::arrow::Result<int64_t> LanceFragment::FastCountRow() const {
assert(!fragment_->data_files().empty());
ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(0));
Expand Down
13 changes: 7 additions & 6 deletions cpp/src/lance/arrow/fragment.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,22 @@ class LanceFragment : public ::arrow::dataset::Fragment {
///
/// It creates a LanceFragment from `arrow.dataset.FileFragment`.
/// \param file_fragment plain dataset file fragment
/// \param schema the schema of dataset.
/// \param manifest dataset manifest.
/// \return LanceFragment
static ::arrow::Result<std::shared_ptr<LanceFragment>> Make(
const ::arrow::dataset::FileFragment& file_fragment, std::shared_ptr<format::Schema> schema);
const ::arrow::dataset::FileFragment& file_fragment,
std::shared_ptr<format::Manifest> manifest);

/// Constructor
///
/// \param fs a file system instance to conduct IOs.
/// \param data_dir the base directory to store data.
/// \param fragment data fragment, the metadata of the fragment.
/// \param schema the schema of the Fragment.
/// \param manifest dataset manifest.
LanceFragment(std::shared_ptr<::arrow::fs::FileSystem> fs,
std::string data_dir,
std::shared_ptr<lance::format::DataFragment> fragment,
std::shared_ptr<format::Schema> schema);
std::shared_ptr<lance::format::Manifest> manifest);

/// Destructor.
~LanceFragment() override = default;
Expand All @@ -79,7 +80,7 @@ class LanceFragment : public ::arrow::dataset::Fragment {
::arrow::internal::Executor* executor = ::arrow::internal::GetCpuThreadPool()) const;

/// Dataset schema.
const std::shared_ptr<format::Schema>& schema() const { return schema_; }
const std::shared_ptr<format::Schema>& schema() const;

protected:
::arrow::Result<std::shared_ptr<::arrow::Schema>> ReadPhysicalSchemaImpl() override;
Expand All @@ -98,7 +99,7 @@ class LanceFragment : public ::arrow::dataset::Fragment {
std::shared_ptr<::arrow::fs::FileSystem> fs_;
std::string data_uri_;
std::shared_ptr<lance::format::DataFragment> fragment_;
std::shared_ptr<format::Schema> schema_;
std::shared_ptr<format::Manifest> manifest_;
};

} // namespace lance::arrow
6 changes: 6 additions & 0 deletions cpp/src/lance/format/manifest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ ::arrow::Result<std::shared_ptr<Manifest>> Manifest::Parse(
return std::shared_ptr<Manifest>(new Manifest(pb));
}

/// Parse a Manifest from a buffer.
///
/// The buffer is decoded as a `pb::Manifest` protobuf message; any decoding
/// error is propagated to the caller.
///
/// \param buffer the buffer holding the serialized manifest.
/// \return the parsed Manifest, or the error raised while decoding.
::arrow::Result<std::shared_ptr<Manifest>> Manifest::Parse(
    const std::shared_ptr<::arrow::Buffer>& buffer) {
  ARROW_ASSIGN_OR_RAISE(auto proto, io::ParseProto<pb::Manifest>(buffer));
  std::shared_ptr<Manifest> manifest(new Manifest(proto));
  return manifest;
}

::arrow::Result<int64_t> Manifest::Write(std::shared_ptr<::arrow::io::OutputStream> out) const {
lance::format::pb::Manifest pb;
for (auto field : schema_->ToProto()) {
Expand Down
7 changes: 6 additions & 1 deletion cpp/src/lance/format/manifest.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <arrow/buffer.h>
#include <arrow/io/api.h>
#include <arrow/result.h>

Expand Down Expand Up @@ -54,6 +55,10 @@ class Manifest final {
static ::arrow::Result<std::shared_ptr<Manifest>> Parse(
std::shared_ptr<::arrow::io::RandomAccessFile> in, int64_t offset);

/// Parse a Manifest from a buffer.
static ::arrow::Result<std::shared_ptr<Manifest>> Parse(
const std::shared_ptr<::arrow::Buffer>& buffer);

/// Write the Manifest to a file.
///
/// \param out the output stream to write this Manifest to.
Expand Down Expand Up @@ -87,7 +92,7 @@ class Manifest final {

std::vector<std::shared_ptr<DataFragment>> fragments_;

Manifest(const lance::format::pb::Manifest& pb);
explicit Manifest(const lance::format::pb::Manifest& pb);
};

} // namespace lance::format
24 changes: 17 additions & 7 deletions cpp/src/lance/format/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ void Field::Init(std::shared_ptr<::arrow::DataType> dtype) {
if (::lance::arrow::is_struct(dtype)) {
auto struct_type = std::static_pointer_cast<::arrow::StructType>(dtype);
for (auto& arrow_field : struct_type->fields()) {
children_.push_back(std::shared_ptr<Field>(new Field(arrow_field)));
children_.push_back(std::make_shared<Field>(arrow_field));
}
} else if (::lance::arrow::is_list(dtype)) {
auto list_type = std::static_pointer_cast<::arrow::ListType>(dtype);
children_.emplace_back(
std::shared_ptr<Field>(new Field(::arrow::field("item", list_type->value_type()))));
std::make_shared<Field>(::arrow::field("item", list_type->value_type())));
encoding_ = encodings::PLAIN;
} else if (::arrow::is_binary_like(type_id) || ::arrow::is_large_binary_like(type_id)) {
encoding_ = encodings::VAR_BINARY;
Expand All @@ -87,9 +87,12 @@ Field::Field(const pb::Field& pb)
name_(pb.name()),
logical_type_(pb.logical_type()),
extension_name_(pb.extension_name()),
encoding_(lance::encodings::FromProto(pb.encoding())),
dictionary_offset_(pb.dictionary_offset()),
dictionary_page_length_(pb.dictionary_page_length()) {}
encoding_(lance::encodings::FromProto(pb.encoding())) {
if (pb.has_dictionary()) {
dictionary_offset_ = pb.dictionary().offset();
dictionary_page_length_ = pb.dictionary().length();
}
}

/// Append \p child to the list of children of this field.
void Field::AddChild(std::shared_ptr<Field> child) {
  // `child` is already a by-value copy; move it into the container instead of
  // copying the shared_ptr (and bumping its reference count) a second time.
  children_.emplace_back(std::move(child));
}

Expand Down Expand Up @@ -165,6 +168,9 @@ std::string Field::ToString() const {
if (is_extension_type()) {
str = fmt::format("{}, extension_name={}", str, extension_name_);
}
if (dictionary_) {
str = fmt::format("{}, dict={}", str, dictionary_->ToString());
}
return str;
}

Expand Down Expand Up @@ -291,8 +297,12 @@ std::vector<lance::format::pb::Field> Field::ToProto() const {
field.set_logical_type(logical_type_);
field.set_extension_name(extension_name_);
field.set_encoding(::lance::encodings::ToProto(encoding_));
field.set_dictionary_offset(dictionary_offset_);
field.set_dictionary_page_length(dictionary_page_length_);

if (dictionary_offset_ >= 0) {
field.mutable_dictionary()->set_offset(dictionary_offset_);
field.mutable_dictionary()->set_length(dictionary_page_length_);
}

field.set_type(GetNodeType());

pb_fields.emplace_back(field);
Expand Down
Loading