Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursively set dictionaries in struct fields #662

Merged
merged 8 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions rust/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use std::fmt::{self};
use arrow_array::types::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{cast::as_dictionary_array, ArrayRef, RecordBatch};
use arrow_array::{
cast::as_dictionary_array, Array, ArrayRef, LargeListArray, ListArray, RecordBatch, StructArray,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit};
use async_recursion::async_recursion;

Expand Down Expand Up @@ -299,6 +301,7 @@ impl Field {
/// the dictionary to the manifest.
pub(crate) fn set_dictionary_values(&mut self, arr: &ArrayRef) {
assert!(self.data_type().is_dictionary());
// offset / length are set to 0 and recomputed when the dictionary is persisted to disk
self.dictionary = Some(Dictionary {
offset: 0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set the offset as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment clarifying when / how these fields are set

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

length: 0,
Expand Down Expand Up @@ -338,8 +341,19 @@ impl Field {
panic!("Unsupported dictionary key type: {}", key_type);
}
},
DataType::Struct(subfields) => {
    // A struct-typed field is expected to be backed by a `StructArray`. Downcast once
    // up front instead of repeating the (loop-invariant) downcast for every child.
    let struct_arr = arr
        .as_any()
        .downcast_ref::<StructArray>()
        .expect("a struct-typed field must be backed by a StructArray");
    // Column `i` of the struct array is passed for arrow subfield `i`; the lance child
    // that receives it is looked up by name.
    for (i, f) in subfields.iter().enumerate() {
        let lance_field = self
            .children
            .iter_mut()
            // Compare in place: no temporary `String` is allocated per candidate child.
            .find(|c| c.name == *f.name())
            .unwrap_or_else(|| panic!("no child field named {:?} in struct field", f.name()));
        // Recurse so dictionaries nested at any depth inside the struct are picked up.
        lance_field.set_dictionary(struct_arr.column(i));
    }
}
_ => {
// Add nested struct support.
// Add list / large list support.
}
}
}
Expand Down
60 changes: 58 additions & 2 deletions rust/src/io/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,15 +539,18 @@ async fn read_large_list_array(
mod tests {
use super::*;

use crate::dataset::{Dataset, WriteParams};
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::{
builder::{Int32Builder, ListBuilder, StringBuilder},
cast::{as_primitive_array, as_string_array, as_struct_array},
types::UInt8Type,
Array, DictionaryArray, Float32Array, Int64Array, NullArray, StringArray, StructArray,
UInt32Array, UInt8Array,
Array, DictionaryArray, Float32Array, Int64Array, NullArray, RecordBatchReader,
StringArray, StructArray, UInt32Array, UInt8Array,
};
use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
use futures::StreamExt;
use tempfile::tempdir;

use crate::io::FileWriter;

Expand Down Expand Up @@ -867,6 +870,59 @@ mod tests {
(arrow_schema, struct_array)
}

// Writes a record batch holding a struct column whose only child is a
// dictionary-encoded string array, scans the dataset back, and checks that the
// first batch returned equals the batch that was written.
#[tokio::test]
async fn test_read_struct_of_dictionary_arrays() {
    // Child column definition, shared by the schema and the struct array.
    let dict_field = ArrowField::new(
        "d",
        DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)),
        false,
    );
    let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "s",
        DataType::Struct(vec![dict_field.clone()]),
        true,
    )]));

    // Three distinct values, appended in order.
    let mut builder = StringDictionaryBuilder::<UInt8Type>::new();
    for value in ["a", "b", "c"] {
        builder.append(value).unwrap();
    }
    let dict_column: ArrayRef = Arc::new(builder.finish());
    let struct_column: ArrayRef = Arc::new(StructArray::from(vec![(dict_field, dict_column)]));

    let expected = RecordBatch::try_new(schema, vec![struct_column]).unwrap();

    // The temporary directory must stay alive until the scan below has finished.
    let tmp = tempdir().unwrap();
    let uri = tmp.path().to_str().unwrap();

    let mut reader: Box<dyn RecordBatchReader> =
        Box::new(crate::arrow::RecordBatchBuffer::new(vec![expected.clone()]));
    Dataset::write(&mut reader, uri, Some(WriteParams::default()))
        .await
        .unwrap();

    let actual = scan_dataset(uri).await.unwrap();
    assert_eq!(expected, actual.as_slice()[0]);
}

/// Opens the dataset at `uri`, scans it, and gathers the record batches yielded
/// by the scan stream into a `Vec`.
///
/// Errors from opening the dataset, building the stream, or collecting the
/// batches are propagated to the caller via `?`.
async fn scan_dataset(uri: &str) -> Result<Vec<RecordBatch>> {
    let dataset = Dataset::open(uri).await?;
    let batches = dataset
        .scan()
        .try_into_stream()
        .await?
        .try_collect::<Vec<_>>()
        .await?;
    Ok(batches)
}

#[tokio::test]
async fn test_read_nullable_arrays() {
use arrow_array::Array;
Expand Down