Skip to content

Commit

Permalink
Implement support for list of Dictionaries
Browse files Browse the repository at this point in the history
- Fixed a bug in write_manifest that would skip writing the dict values for the first field in the schema
  • Loading branch information
gsilvestrin committed Mar 21, 2023
1 parent 780e38b commit 4c3c350
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
14 changes: 12 additions & 2 deletions rust/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use std::fmt::{self};
use arrow_array::types::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{cast::as_dictionary_array, Array, ArrayRef, RecordBatch, StructArray};
use arrow_array::{
cast::as_dictionary_array, Array, ArrayRef, LargeListArray, ListArray, RecordBatch, StructArray,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, TimeUnit};
use async_recursion::async_recursion;

Expand Down Expand Up @@ -375,8 +377,16 @@ impl Field {
lance_field.set_dictionary(struct_arr.column(i));
}
}
DataType::List(_) => {
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap();
self.children[0].set_dictionary(list_arr.values());
}
DataType::LargeList(_) => {
let list_arr = arr.as_any().downcast_ref::<LargeListArray>().unwrap();
self.children[0].set_dictionary(list_arr.values());
}
_ => {
// Add list / large list support.
// Add list / large list support. - should we panic?
}
}
}
Expand Down
40 changes: 39 additions & 1 deletion rust/src/io/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub async fn write_manifest(
) -> Result<usize> {
// Write dictionary values.
let max_field_id = manifest.schema.max_field_id().unwrap_or(-1);
for field_id in 1..max_field_id + 1 {
for field_id in 0..max_field_id + 1 {
if let Some(field) = manifest.schema.mut_field_by_id(field_id) {
if field.data_type().is_dictionary() {
let dict_info = field.dictionary.as_mut().ok_or_else(|| {
Expand Down Expand Up @@ -414,6 +414,24 @@ mod tests {
DataType::LargeList(Box::new(ArrowField::new("item", DataType::Utf8, true))),
true,
),
ArrowField::new(
"l_dict",
DataType::List(Box::new(ArrowField::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
))),
true,
),
ArrowField::new(
"large_l_dict",
DataType::LargeList(Box::new(ArrowField::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
))),
true,
),
ArrowField::new(
"s",
DataType::Struct(vec![
Expand Down Expand Up @@ -452,6 +470,24 @@ mod tests {
let large_list_arr =
LargeListArray::try_new(large_list_values, &large_list_offsets).unwrap();

let list_dict_offsets = (0..202).step_by(2).collect();
let list_dict_vec = (0..200)
.into_iter()
.map(|n| ["a", "b", "c"][n % 3])
.collect::<Vec<_>>();
let list_dict_arr: DictionaryArray<UInt32Type> = list_dict_vec.into_iter().collect();
let list_dict_arr = ListArray::try_new(list_dict_arr, &list_dict_offsets).unwrap();

let large_list_dict_offsets: Int64Array = (0..202).step_by(2).collect();
let large_list_dict_vec = (0..200)
.into_iter()
.map(|n| ["a", "b", "c"][n % 3])
.collect::<Vec<_>>();
let large_list_dict_arr: DictionaryArray<UInt32Type> =
large_list_dict_vec.into_iter().collect();
let large_list_dict_arr =
LargeListArray::try_new(large_list_dict_arr, &large_list_dict_offsets).unwrap();

let columns: Vec<ArrayRef> = vec![
Arc::new(NullArray::new(100)),
Arc::new(BooleanArray::from_iter(
Expand Down Expand Up @@ -491,6 +527,8 @@ mod tests {
Arc::new(fixed_size_binary_arr),
Arc::new(list_arr),
Arc::new(large_list_arr),
Arc::new(list_dict_arr),
Arc::new(large_list_dict_arr),
Arc::new(StructArray::from(vec![
(
ArrowField::new("si", DataType::Int64, true),
Expand Down

0 comments on commit 4c3c350

Please sign in to comment.