diff --git a/src/array/fmt.rs b/src/array/fmt.rs index 39f4f8b1603..47b72fb9cac 100644 --- a/src/array/fmt.rs +++ b/src/array/fmt.rs @@ -88,7 +88,9 @@ pub fn get_value_display<'a, F: Write + 'a>( Union => Box::new(move |f, index| { super::union::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) }), - Map => todo!(), + Map => Box::new(move |f, index| { + super::map::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), Dictionary(key_type) => match_integer_type!(key_type, |$T| { Box::new(move |f, index| { super::dictionary::fmt::write_value::<$T,_>(array.as_any().downcast_ref().unwrap(), index, null, f) diff --git a/src/array/map/mod.rs b/src/array/map/mod.rs index 51ca56a8459..aaeaba6745c 100644 --- a/src/array/map/mod.rs +++ b/src/array/map/mod.rs @@ -8,9 +8,11 @@ use crate::{ use super::{new_empty_array, specification::try_check_offsets, Array}; mod ffi; -mod fmt; +pub(super) mod fmt; mod iterator; pub use iterator::*; +mod mutable; +pub use mutable::*; /// An array representing a (key, value), both of arbitrary logical types. #[derive(Clone)] diff --git a/src/array/map/mutable.rs b/src/array/map/mutable.rs new file mode 100644 index 00000000000..e614217d77e --- /dev/null +++ b/src/array/map/mutable.rs @@ -0,0 +1,192 @@ +use std::sync::Arc; + +use crate::{ + array::{Array, MutableArray, MutableStructArray, StructArray}, + bitmap::MutableBitmap, + datatypes::DataType, + error::{Error, Result}, + types::Index, +}; + +use super::MapArray; + +/// The mutable version lf [`MapArray`]. +#[derive(Debug)] +pub struct MutableMapArray { + data_type: DataType, + offsets: Vec, + field: MutableStructArray, + validity: Option, +} + +impl From for MapArray { + fn from(other: MutableMapArray) -> Self { + let validity = if other.validity.as_ref().map(|x| x.unset_bits()).unwrap_or(0) > 0 { + other.validity.map(|x| x.into()) + } else { + None + }; + + let field: StructArray = other.field.into(); + + MapArray::from_data( + other.data_type, + other.offsets.into(), + Box::new(field), + validity, + ) + } +} + +impl MutableMapArray { + /// Creates a new empty [`MutableMapArray`]. + /// # Errors + /// This function errors if: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`] + /// * The fields' `data_type` is not equal to the inner field of `data_type` + pub fn try_new(data_type: DataType, values: Vec>) -> Result { + let field = MutableStructArray::try_from_data( + MapArray::get_field(&data_type).data_type().clone(), + values, + None, + )?; + Ok(Self { + data_type, + offsets: vec![0i32; 1], + field, + validity: None, + }) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len() - 1 + } + + /// Mutable reference to the field + pub fn mut_field(&mut self) -> &mut MutableStructArray { + &mut self.field + } + + /// Reference to the field + pub fn field(&self) -> &MutableStructArray { + &self.field + } + + /// Get a mutable reference to the keys and values arrays + pub fn keys_values(&mut self) -> Option<(&mut K, &mut V)> + where + K: MutableArray + 'static, + V: MutableArray + 'static, + { + let [keys, values]: &mut [_; 2] = + self.field.mut_values().as_mut_slice().try_into().unwrap(); + + match ( + keys.as_mut_any().downcast_mut(), + values.as_mut_any().downcast_mut(), + ) { + (Some(keys), Some(values)) => Some((keys, values)), + _ => None, + } + } + + /// Call this once for each "row" of children you push. + pub fn push(&mut self, valid: bool) { + match &mut self.validity { + Some(validity) => validity.push(valid), + None => match valid { + true => (), + false => self.init_validity(), + }, + } + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + pub fn try_push_valid(&mut self) -> Result<()> { + let size = self.field.len(); + let size = ::from_usize(size).ok_or(Error::Overflow)?; + assert!(size >= *self.offsets.last().unwrap()); + self.offsets.push(size); + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + fn push_null(&mut self) { + self.field.push(false); + self.push(false); + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.field.len()); + let len = self.len(); + if len > 0 { + validity.extend_constant(len, true); + validity.set(len - 1, false); + } + self.validity = Some(validity); + } + + fn take_into(&mut self) -> MapArray { + MapArray::from_data( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.field.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + } +} + +impl MutableArray for MutableMapArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + Box::new(self.take_into()) + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.take_into()) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + self.field.reserve(additional); + if let Some(validity) = &mut self.validity { + validity.reserve(additional) + } + } + + fn shrink_to_fit(&mut self) { + self.offsets.shrink_to_fit(); + self.field.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit(); + } + } +} diff --git a/src/array/mod.rs b/src/array/mod.rs index 45e83a80803..a25e21e65c0 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -446,7 +446,7 @@ pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; pub use fixed_size_binary::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; pub use fixed_size_list::{FixedSizeListArray, MutableFixedSizeListArray}; pub use list::{ListArray, ListValuesIter, MutableListArray}; -pub use map::MapArray; +pub use map::{MapArray, MutableMapArray}; pub use null::NullArray; pub use primitive::*; pub use struct_::{MutableStructArray, StructArray}; diff --git a/src/array/struct_/mutable.rs b/src/array/struct_/mutable.rs index b35e1064b94..d3b4c577e1b 100644 --- a/src/array/struct_/mutable.rs +++ b/src/array/struct_/mutable.rs @@ -4,6 +4,7 @@ use crate::{ array::{Array, MutableArray}, bitmap::MutableBitmap, datatypes::DataType, + error::Error, }; use super::StructArray; @@ -38,44 +39,77 @@ impl MutableStructArray { Self::from_data(data_type, values, None) } - /// Create a [`MutableStructArray`] out of low-end APIs. - /// # Panics - /// This function panics iff: + /// Fallibly create a [`MutableStructArray`] out of low-level APIs. + /// # Errors + /// This function returns an error if: /// * `data_type` is not [`DataType::Struct`] /// * The inner types of `data_type` are not equal to those of `values` /// * `validity` is not `None` and its length is different from the `values`'s length - pub fn from_data( + pub fn try_from_data( data_type: DataType, values: Vec>, validity: Option, - ) -> Self { + ) -> Result { match data_type.to_logical_type() { - DataType::Struct(ref fields) => assert!(fields - .iter() - .map(|f| f.data_type()) - .eq(values.iter().map(|f| f.data_type()))), - _ => panic!("StructArray must be initialized with DataType::Struct"), - }; + DataType::Struct(ref fields) => { + if !fields + .iter() + .map(|f| f.data_type()) + .eq(values.iter().map(|f| f.data_type())) + { + Err(crate::error::Error::InvalidArgumentError( + "DataType::Struct fields must match those found in `values`.".to_owned(), + )) + } else { + Ok(()) + } + } + _ => Err(crate::error::Error::InvalidArgumentError( + "StructArray must be initialized with DataType::Struct".to_owned(), + )), + }?; let self_ = Self { data_type, values, validity, }; - self_.assert_lengths(); - self_ + self_.check_lengths().map(|_| self_) } - fn assert_lengths(&self) { + /// Create a [`MutableStructArray`] out of low-end APIs. + /// # Panics + /// This function panics iff: + /// * `data_type` is not [`DataType::Struct`] + /// * The inner types of `data_type` are not equal to those of `values` + /// * `validity` is not `None` and its length is different from the `values`'s length + pub fn from_data( + data_type: DataType, + values: Vec>, + validity: Option, + ) -> Self { + Self::try_from_data(data_type, values, validity).unwrap() + } + + fn check_lengths(&self) -> Result<(), Error> { let first_len = self.values.first().map(|v| v.len()); if let Some(len) = first_len { if !self.values.iter().all(|x| x.len() == len) { let lengths: Vec<_> = self.values.iter().map(|v| v.len()).collect(); - panic!("StructArray child lengths differ: {:?}", lengths); + return Err(Error::InvalidArgumentError(format!( + "StructArray child lengths differ: {lengths:?}" + ))); } } if let Some(validity) = &self.validity { - assert_eq!(first_len.unwrap_or(0), validity.len()); + let struct_len = first_len.unwrap_or(0); + let validity_len = validity.len(); + if struct_len != validity_len { + return Err(Error::InvalidArgumentError(format!( + "StructArray child lengths ({struct_len}) differ from validity ({validity_len})", + ))); + } } + Ok(()) } /// Extract the low-end APIs from the [`MutableStructArray`]. diff --git a/tests/it/array/map/mod.rs b/tests/it/array/map/mod.rs index 38fde84367e..262690519a9 100644 --- a/tests/it/array/map/mod.rs +++ b/tests/it/array/map/mod.rs @@ -1,3 +1,5 @@ +mod mutable; + use arrow2::{ array::*, datatypes::{DataType, Field}, diff --git a/tests/it/array/map/mutable.rs b/tests/it/array/map/mutable.rs new file mode 100644 index 00000000000..2162551ebf2 --- /dev/null +++ b/tests/it/array/map/mutable.rs @@ -0,0 +1,61 @@ +use arrow2::{ + array::{MapArray, MutableArray, MutableMapArray, MutableUtf8Array}, + datatypes::{DataType, Field}, +}; + +#[test] +fn basics() { + let dt = DataType::Struct(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ]); + let data_type = DataType::Map(Box::new(Field::new("a", dt, true)), false); + + let values = vec![ + Box::new(MutableUtf8Array::::new()) as Box, + Box::new(MutableUtf8Array::::new()) as Box, + ]; + + let mut array = MutableMapArray::try_new(data_type, values).unwrap(); + assert_eq!(array.len(), 0); + + let field = array.mut_field(); + field + .value::>(0) + .unwrap() + .extend([Some("a"), Some("aa"), Some("aaa")]); + field + .value::>(1) + .unwrap() + .extend([Some("b"), Some("bb"), Some("bbb")]); + array.try_push_valid().unwrap(); + assert_eq!(array.len(), 1); + + array.keys::>().push(Some("foo")); + array.values::>().push(Some("bar")); + array.try_push_valid().unwrap(); + assert_eq!(array.len(), 2); + + let map: MapArray = array.into(); + dbg!(map); +} + +#[test] +fn failure() { + let dt = DataType::Struct(vec![ + Field::new("key", DataType::Utf8, true), + Field::new("value", DataType::Utf8, true), + Field::new("extra", DataType::Utf8, true), + ]); + let data_type = DataType::Map(Box::new(Field::new("item", dt, true)), false); + + let values = vec![ + Box::new(MutableUtf8Array::::new()) as Box, + Box::new(MutableUtf8Array::::new()) as Box, + ]; + + assert!(matches!( + MutableMapArray::try_new(data_type, values), + Err(arrow2::error::Error::InvalidArgumentError(_)) + )); +}