Skip to content

Commit

Permalink
add test for reading decimal value from primitive array reader
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Aug 11, 2022
1 parent 27f4762 commit 435458c
Showing 1 changed file with 124 additions and 17 deletions.
141 changes: 124 additions & 17 deletions parquet/src/arrow/array_reader/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ use std::sync::Arc;
/// Primitive array readers are leaves of array reader tree. They accept page iterator
/// and read them into primitive arrays.
pub struct PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
data_type: ArrowType,
pages: Box<dyn PageIterator>,
Expand All @@ -48,9 +48,9 @@ where
}

impl<T> PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
/// Construct primitive array reader.
pub fn new(
Expand Down Expand Up @@ -80,9 +80,9 @@ where

/// Implementation of primitive array reader.
impl<T> ArrayReader for PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -203,10 +203,10 @@ where
return Err(arrow_err!(
"Cannot convert {:?} to decimal",
array.data_type()
))
));
}
}
.with_precision_and_scale(p, s)?;
.with_precision_and_scale(p, s)?;

Arc::new(array) as ArrayRef
}
Expand Down Expand Up @@ -239,16 +239,17 @@ mod tests {
use crate::arrow::array_reader::test_util::EmptyPageIterator;
use crate::basic::Encoding;
use crate::column::page::Page;
use crate::data_type::Int32Type;
use crate::data_type::{Int32Type, Int64Type};
use crate::schema::parser::parse_message_type;
use crate::schema::types::SchemaDescriptor;
use crate::util::test_common::rand_gen::make_pages;
use crate::util::InMemoryPageIterator;
use arrow::array::PrimitiveArray;
use arrow::datatypes::ArrowPrimitiveType;
use arrow::array::{Array, PrimitiveArray};
use arrow::datatypes::{ArrowPrimitiveType};

use rand::distributions::uniform::SampleUniform;
use std::collections::VecDeque;
use arrow::datatypes::DataType::Decimal128;

#[allow(clippy::too_many_arguments)]
fn make_column_chunks<T: DataType>(
Expand Down Expand Up @@ -314,7 +315,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

// expect no values to be read
let array = array_reader.next_batch(50).unwrap();
Expand Down Expand Up @@ -361,7 +362,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

// Read first 50 values, which are all from the first column chunk
let array = array_reader.next_batch(50).unwrap();
Expand Down Expand Up @@ -561,7 +562,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

let mut accu_len: usize = 0;

Expand Down Expand Up @@ -602,4 +603,110 @@ mod tests {
);
}
}


#[test]
fn test_primitive_array_reader_decimal_types() {
// parquet `INT32` to decimal
let message_type = "
message test_schema {
REQUIRED INT32 decimal1 (DECIMAL(8,2));
}
";
let schema = parse_message_type(message_type)
.map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t))))
.unwrap();
let column_desc = schema.column(0);

// create the array reader
{
let mut data = Vec::new();
let mut page_lists = Vec::new();
make_column_chunks::<Int32Type>(
column_desc.clone(),
Encoding::PLAIN,
100,
-99999999,
99999999,
&mut Vec::new(),
&mut Vec::new(),
&mut data,
&mut page_lists,
true,
2,
);
let page_iterator =
InMemoryPageIterator::new(schema, column_desc.clone(), page_lists);

let mut array_reader = PrimitiveArrayReader::<Int32Type>::new(
Box::new(page_iterator),
column_desc,
None,
)
.unwrap();

// read data from the reader
// the data type is decimal(8,2)
let array = array_reader.next_batch(50).unwrap();
assert_eq!(array.data_type(), &Decimal128(8, 2));
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
let data_decimal_array = data[0..50].iter().copied().map(|v| Some(v as i128)).collect::<Decimal128Array>().with_precision_and_scale(8, 2).unwrap();
assert_eq!(array, &data_decimal_array);

// not equal with different data type(precision and scale)
let data_decimal_array = data[0..50].iter().copied().map(|v| Some(v as i128)).collect::<Decimal128Array>().with_precision_and_scale(9, 0).unwrap();
assert_ne!(array, &data_decimal_array)
}

// parquet `INT64` to decimal
let message_type = "
message test_schema {
REQUIRED INT64 decimal1 (DECIMAL(18,4));
}
";
let schema = parse_message_type(message_type)
.map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t))))
.unwrap();
let column_desc = schema.column(0);

// create the array reader
{
let mut data = Vec::new();
let mut page_lists = Vec::new();
make_column_chunks::<Int64Type>(
column_desc.clone(),
Encoding::PLAIN,
100,
-999999999999999999,
999999999999999999,
&mut Vec::new(),
&mut Vec::new(),
&mut data,
&mut page_lists,
true,
2,
);
let page_iterator =
InMemoryPageIterator::new(schema, column_desc.clone(), page_lists);

let mut array_reader = PrimitiveArrayReader::<Int64Type>::new(
Box::new(page_iterator),
column_desc,
None,
)
.unwrap();

// read data from the reader
// the data type is decimal(18,4)
let array = array_reader.next_batch(50).unwrap();
assert_eq!(array.data_type(), &Decimal128(18, 4));
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
let data_decimal_array = data[0..50].iter().copied().map(|v| Some(v as i128)).collect::<Decimal128Array>().with_precision_and_scale(18, 4).unwrap();
assert_eq!(array, &data_decimal_array);

// not equal with different data type(precision and scale)
let data_decimal_array = data[0..50].iter().copied().map(|v| Some(v as i128)).collect::<Decimal128Array>().with_precision_and_scale(34, 0).unwrap();
assert_ne!(array, &data_decimal_array)
}
}
}

0 comments on commit 435458c

Please sign in to comment.