diff --git a/rust/src/arrow.rs b/rust/src/arrow.rs index 2e3f02c13e..24c0fbf80b 100644 --- a/rust/src/arrow.rs +++ b/rust/src/arrow.rs @@ -19,12 +19,10 @@ //! //! To improve Arrow-RS egonomitic -use arrow_array::types::UInt8Type; use arrow_array::{ Array, FixedSizeBinaryArray, FixedSizeListArray, Int32Array, ListArray, UInt8Array, }; -use arrow_data::{ArrayData, ArrayDataBuilder}; -use arrow_schema::DataType::FixedSizeBinary; +use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field}; use crate::error::Result; diff --git a/rust/src/datatypes.rs b/rust/src/datatypes.rs index b4e85f0d01..440d57476f 100644 --- a/rust/src/datatypes.rs +++ b/rust/src/datatypes.rs @@ -319,6 +319,18 @@ impl Field { ) } + /// Recursively set field ID and parent ID for this field and all its children. + fn set_id(&mut self, parent_id: i32, id_seed: &mut i32) { + self.parent_id = parent_id; + if self.id < 0 { + self.id = *id_seed; + *id_seed += 1; + } + self.children + .iter_mut() + .for_each(|f| f.set_id(self.id, id_seed)); + } + // Find any nested child with a specific field id fn mut_field_by_id(&mut self, id: i32) -> Option<&mut Field> { for child in self.children.as_mut_slice() { @@ -558,6 +570,13 @@ impl Schema { } Ok(()) } + + fn set_field_id(&mut self) { + let mut current_id = self.max_field_id().unwrap_or(-1) + 1; + self.fields + .iter_mut() + .for_each(|f| f.set_id(-1, &mut current_id)); + } } impl fmt::Display for Schema { @@ -574,14 +593,17 @@ impl TryFrom<&ArrowSchema> for Schema { type Error = Error; fn try_from(schema: &ArrowSchema) -> Result { - Ok(Self { + let mut schema = Self { fields: schema .fields .iter() .map(Field::try_from) .collect::>()?, metadata: schema.metadata.clone(), - }) + }; + schema.set_field_id(); + + Ok(schema) } } @@ -629,10 +651,10 @@ impl From<&Schema> for Vec { #[cfg(test)] mod tests { - use arrow_schema::{Field as ArrowField, TimeUnit}; - use super::*; + use arrow_schema::{Field as ArrowField, TimeUnit}; + #[test] fn arrow_field_to_field() { for (name, data_type) in [ @@ -759,7 +781,7 @@ mod tests { ), ArrowField::new("c", DataType::Float64, false), ]); - assert_eq!(projected, Schema::try_from(&expected_arrow_schema).unwrap()); + assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema); } #[test] @@ -778,6 +800,43 @@ mod tests { ArrowField::new("c", DataType::Float64, false), ]); let schema = Schema::try_from(&arrow_schema).unwrap(); - let projected = schema.project_by_ids(&[1, 4, 5]).unwrap(); + let projected = schema.project_by_ids(&[1, 2, 4, 5]).unwrap(); + + let expected_arrow_schema = ArrowSchema::new(vec![ + ArrowField::new( + "b", + DataType::Struct(vec![ + ArrowField::new("f1", DataType::Utf8, true), + ArrowField::new("f3", DataType::Float32, false), + ]), + true, + ), + ArrowField::new("c", DataType::Float64, false), + ]); + assert_eq!(ArrowSchema::from(&projected), expected_arrow_schema); + } + + #[test] + fn test_schema_set_ids() { + let arrow_schema = ArrowSchema::new(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new( + "b", + DataType::Struct(vec![ + ArrowField::new("f1", DataType::Utf8, true), + ArrowField::new("f2", DataType::Boolean, false), + ArrowField::new("f3", DataType::Float32, false), + ]), + true, + ), + ArrowField::new("c", DataType::Float64, false), + ]); + let schema = Schema::try_from(&arrow_schema).unwrap(); + + let protos: Vec = (&schema).into(); + assert_eq!( + protos.iter().map(|p| p.id).collect::>(), + (0..6).collect::>() + ); } } diff --git a/rust/src/encodings/plain.rs b/rust/src/encodings/plain.rs index 1bc65caa74..00a0cc3707 100644 --- a/rust/src/encodings/plain.rs +++ b/rust/src/encodings/plain.rs @@ -224,7 +224,6 @@ impl<'a> Decoder for PlainDecoder<'a> { mod tests { use crate::io::ObjectStore; use arrow_array::*; - use arrow_schema::DataType::FixedSizeList; use arrow_schema::Field; use object_store::path::Path; use rand::prelude::*;