Skip to content

Commit

Permalink
[Python] Test projection in Python Torch Dataset (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Sep 27, 2022
1 parent 10d53e3 commit 9a6a4d7
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 112 deletions.
1 change: 0 additions & 1 deletion cpp/src/lance/arrow/scanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ TEST_CASE("Test filter with smaller batch size than block size") {
CHECK(scan_builder.BatchSize(7).ok()); // Some number that is not dividable by the group size.
auto scanner = scan_builder.Finish().ValueOrDie();
auto actual = scanner->ToTable().ValueOrDie();
fmt::print("Actual table: {}\n", actual->ToString());

std::vector<std::string> expected_strs;
for (size_t i = 0; i < ints.size() / 5; i++) {
Expand Down
59 changes: 0 additions & 59 deletions cpp/src/lance/arrow/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,63 +104,4 @@ ::arrow::Result<std::shared_ptr<::arrow::StructArray>> MergeStructArrays(
return ::arrow::StructArray::Make(arrays, names);
}

std::string ColumnNameFromFieldRef(const ::arrow::FieldRef& ref) {
if (ref.IsName()) {
return *ref.name();
}
assert(ref.IsNested());
std::string name;
for (auto& child : *ref.nested_refs()) {
if (child.IsFieldPath()) {
continue;
}
if (child.IsName()) {
if (!name.empty()) {
name += ".";
}
name += *child.name();
}
}
return name;
}

::arrow::Result<std::shared_ptr<::arrow::StructArray>> ApplyProjection(
const ::arrow::StructArray& arr, const format::Field& field) {
assert(is_struct(field.type()->id()));

std::vector<std::shared_ptr<::arrow::Array>> columns;
std::vector<std::string> names;
for (auto& child : field.fields()) {
auto col = arr.GetFieldByName(child->name());
if (!col) {
return ::arrow::Status::Invalid(fmt::format("Column {} not found", child->name()));
}
if (is_struct(col->type_id())) {
ARROW_ASSIGN_OR_RAISE(
col, ApplyProjection(*std::dynamic_pointer_cast<::arrow::StructArray>(col), *child));
}
names.emplace_back(child->name());
columns.emplace_back(col);
}
return ::arrow::StructArray::Make(columns, names);
};

::arrow::Result<std::shared_ptr<::arrow::RecordBatch>> ApplyProjection(
const std::shared_ptr<::arrow::RecordBatch>& batch, const format::Schema& projected_schema) {
std::vector<std::shared_ptr<::arrow::Array>> columns;
for (const auto& field : projected_schema.fields()) {
auto col = batch->GetColumnByName(field->name());
if (!col) {
continue;
}
if (is_struct(col->type_id())) {
ARROW_ASSIGN_OR_RAISE(
col, ApplyProjection(*std::dynamic_pointer_cast<::arrow::StructArray>(col), *field));
}

columns.emplace_back(col);
}
return ::arrow::RecordBatch::Make(projected_schema.ToArrow(), batch->num_rows(), columns);
}

} // namespace lance::arrow
4 changes: 0 additions & 4 deletions cpp/src/lance/arrow/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,5 @@ ::arrow::Result<std::shared_ptr<::arrow::StructArray>> MergeStructArrays(
const std::shared_ptr<::arrow::StructArray>& rhs,
::arrow::MemoryPool* pool = ::arrow::default_memory_pool());

::arrow::Result<std::shared_ptr<::arrow::RecordBatch>> ApplyProjection(
const std::shared_ptr<::arrow::RecordBatch>& batch, const format::Schema& projected_schema);

std::string ColumnNameFromFieldRef(const ::arrow::FieldRef& ref);

} // namespace lance::arrow
39 changes: 0 additions & 39 deletions cpp/src/lance/arrow/utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,42 +95,3 @@ TEST_CASE("Merge nested structs") {
INFO("Actual data: " << points->ToString() << " Expected: " << expected_arr->ToString());
CHECK(points->Equals(expected_arr));
}

TEST_CASE("Apply projection") {
std::vector<std::string> names({"one", "two", "three"});
auto int_values = ToArray({1, 2, 3}).ValueOrDie();
auto str_values = ToArray(names).ValueOrDie();

auto ids_builder = std::make_shared<::arrow::Int16Builder>();
auto names_builder = std::make_shared<::arrow::StringBuilder>();
auto struct_type = arrow::struct_(
{::arrow::field("ids", ::arrow::int16()), ::arrow::field("names", ::arrow::utf8())});
auto annotations_builder = ::arrow::StructBuilder(
struct_type, ::arrow::default_memory_pool(), {ids_builder, names_builder});

for (int i = 1; i <= 3; i++) {
CHECK(annotations_builder.Append().ok());
CHECK(ids_builder->Append(i).ok());
CHECK(names_builder->Append(names[i - 1]).ok());
}
auto ann_arr = annotations_builder.Finish().ValueOrDie();

auto arrow_schema = ::arrow::schema({::arrow::field("ints", ::arrow::int32()),
::arrow::field("strings", ::arrow::utf8()),
::arrow::field("annotations", struct_type)});
auto schema = ::lance::format::Schema(arrow_schema);
auto record_batch =
::arrow::RecordBatch::Make(arrow_schema, 3, {int_values, str_values, ann_arr});
auto projected_schema = schema.Project({"ints", "annotations.names"}).ValueOrDie();
auto projected = lance::arrow::ApplyProjection(record_batch, *projected_schema).ValueOrDie();
CHECK(projected->schema()->Equals(projected_schema->ToArrow()));
CHECK(projected->GetColumnByName("strings") == nullptr);
CHECK(projected->GetColumnByName("ints")->Equals(int_values));

auto actual_ann_arr = projected->GetColumnByName("annotations");
CHECK(actual_ann_arr);
CHECK(lance::arrow::is_struct(actual_ann_arr->type_id()));
auto ann_struct_arr = std::dynamic_pointer_cast<::arrow::StructArray>(actual_ann_arr);
CHECK(ann_struct_arr->GetFieldByName("names")->length() == 3);
CHECK(ann_struct_arr->GetFieldByName("ids") == nullptr);
}
6 changes: 0 additions & 6 deletions cpp/src/lance/io/exec/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,4 @@ int64_t ScanBatch::length() const {
return batch->num_rows();
}

::arrow::Result<ScanBatch> ScanBatch::Project(const lance::format::Schema& projected_schema) {
ARROW_ASSIGN_OR_RAISE(auto projected_batch,
lance::arrow::ApplyProjection(batch, projected_schema));
return ScanBatch{projected_batch, batch_id, offset, indices};
}

} // namespace lance::io::exec
3 changes: 0 additions & 3 deletions cpp/src/lance/io/exec/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ struct ScanBatch {

/// The length of this batch.
int64_t length() const;

/// Project selected columns over the Scanned Batch.
::arrow::Result<ScanBatch> Project(const lance::format::Schema& projected_schema);
};

/// I/O execute base node.
Expand Down
2 changes: 2 additions & 0 deletions python/lance/pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,6 @@ def __iter__(self):
]
if self.transform is not None:
record = self.transform(*record)
if batch.num_columns == 1:
record = record[0]
yield record
11 changes: 11 additions & 0 deletions python/lance/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,14 @@ def test_data_loader_with_filter(tmp_path: Path):
assert torch.is_tensor(id)
assert (value - 10) % 2 == 0
assert torch.is_tensor(value)

def test_data_loader_projection(tmp_path: Path):
ids = pa.array(range(10))
values = pa.array([f"num-{i}" for i in ids])
tab = pa.Table.from_arrays([ids, values], names=["id", "value"])
lance.write_table(tab, tmp_path / "lance")

dataset = LanceDataset(tmp_path / "lance", columns=["value"], filter=pc.field("id") >= 5)
for elem, expected_id in zip(dataset, range(5, 10)):
assert elem == f"num-{expected_id}"

0 comments on commit 9a6a4d7

Please sign in to comment.