diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index e556e7d897..bfb7722d3e 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -8,7 +8,9 @@ use std::fmt::{self}; use arrow_array::types::{ Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{cast::as_dictionary_array, ArrayRef, RecordBatch}; +use arrow_array::{ + cast::as_dictionary_array, Array, ArrayRef, LargeListArray, ListArray, RecordBatch, StructArray, +}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit}; use async_recursion::async_recursion; @@ -299,6 +301,7 @@ impl Field { /// the dictionary to the manifest. pub(crate) fn set_dictionary_values(&mut self, arr: &ArrayRef) { assert!(self.data_type().is_dictionary()); + // offset / length are set to 0 and recomputed when the dictionary is persisted to disk self.dictionary = Some(Dictionary { offset: 0, length: 0, @@ -338,8 +341,19 @@ impl Field { panic!("Unsupported dictionary key type: {}", key_type); } }, + DataType::Struct(mut subfields) => { + for (i, f) in subfields.iter_mut().enumerate() { + let lance_field = self + .children + .iter_mut() + .find(|c| c.name == f.name().to_string()) + .unwrap(); + let struct_arr = arr.as_any().downcast_ref::<StructArray>().unwrap(); + lance_field.set_dictionary(struct_arr.column(i)); + } + } _ => { - // Add nested struct support. + // Add list / large list support. 
} } } diff --git a/rust/src/io/reader.rs b/rust/src/io/reader.rs index 167037377c..1f1136d0f0 100644 --- a/rust/src/io/reader.rs +++ b/rust/src/io/reader.rs @@ -539,15 +539,18 @@ async fn read_large_list_array( mod tests { use super::*; + use crate::dataset::{Dataset, WriteParams}; + use arrow_array::builder::StringDictionaryBuilder; use arrow_array::{ builder::{Int32Builder, ListBuilder, StringBuilder}, cast::{as_primitive_array, as_string_array, as_struct_array}, types::UInt8Type, - Array, DictionaryArray, Float32Array, Int64Array, NullArray, StringArray, StructArray, - UInt32Array, UInt8Array, + Array, DictionaryArray, Float32Array, Int64Array, NullArray, RecordBatchReader, + StringArray, StructArray, UInt32Array, UInt8Array, }; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use futures::StreamExt; + use tempfile::tempdir; use crate::io::FileWriter; @@ -867,6 +870,59 @@ mod tests { (arrow_schema, struct_array) } + #[tokio::test] + async fn test_read_struct_of_dictionary_arrays() { + let test_dir = tempdir().unwrap(); + + let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "s", + DataType::Struct(vec![ArrowField::new( + "d", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]), + true, + )])); + + let mut dict_builder = StringDictionaryBuilder::<UInt8Type>::new(); + dict_builder.append("a").unwrap(); + dict_builder.append("b").unwrap(); + dict_builder.append("c").unwrap(); + + let struct_array = Arc::new(StructArray::from(vec![( + ArrowField::new( + "d", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + ), + Arc::new(dict_builder.finish()) as ArrayRef, + )])); + + let batch = RecordBatch::try_new(arrow_schema.clone(), vec![struct_array.clone()]).unwrap(); + + let test_uri = test_dir.path().to_str().unwrap(); + + let batches = crate::arrow::RecordBatchBuffer::new(vec![batch.clone()]); + let mut batches: Box<dyn RecordBatchReader> = Box::new(batches); + Dataset::write(&mut batches, 
test_uri, Some(WriteParams::default())) + .await + .unwrap(); + + let result = scan_dataset(test_uri).await.unwrap(); + assert_eq!(batch, result.as_slice()[0]); + } + + async fn scan_dataset(uri: &str) -> Result<Vec<RecordBatch>> { + let results = Dataset::open(uri) + .await? + .scan() + .try_into_stream() + .await? + .try_collect::<Vec<_>>() + .await?; + Ok(results) + } + #[tokio::test] async fn test_read_nullable_arrays() { use arrow_array::Array;