Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add test for reading decimal value from primitive array reader #2411

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 124 additions & 17 deletions parquet/src/arrow/array_reader/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ use std::sync::Arc;
/// Primitive array readers are leaves of array reader tree. They accept page iterator
/// and read them into primitive arrays.
pub struct PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
data_type: ArrowType,
pages: Box<dyn PageIterator>,
Expand All @@ -48,9 +48,9 @@ where
}

impl<T> PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
/// Construct primitive array reader.
pub fn new(
Expand Down Expand Up @@ -80,9 +80,9 @@ where

/// Implementation of primitive array reader.
impl<T> ArrayReader for PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
where
T: DataType,
T::T: ScalarValue,
{
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -203,10 +203,10 @@ where
return Err(arrow_err!(
"Cannot convert {:?} to decimal",
array.data_type()
))
));
}
}
.with_precision_and_scale(p, s)?;
.with_precision_and_scale(p, s)?;

Arc::new(array) as ArrayRef
}
Expand Down Expand Up @@ -239,16 +239,17 @@ mod tests {
use crate::arrow::array_reader::test_util::EmptyPageIterator;
use crate::basic::Encoding;
use crate::column::page::Page;
use crate::data_type::Int32Type;
use crate::data_type::{Int32Type, Int64Type};
use crate::schema::parser::parse_message_type;
use crate::schema::types::SchemaDescriptor;
use crate::util::test_common::rand_gen::make_pages;
use crate::util::InMemoryPageIterator;
use arrow::array::PrimitiveArray;
use arrow::datatypes::ArrowPrimitiveType;
use arrow::array::{Array, PrimitiveArray};
use arrow::datatypes::{ArrowPrimitiveType};

use rand::distributions::uniform::SampleUniform;
use std::collections::VecDeque;
use arrow::datatypes::DataType::Decimal128;

#[allow(clippy::too_many_arguments)]
fn make_column_chunks<T: DataType>(
Expand Down Expand Up @@ -314,7 +315,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

// expect no values to be read
let array = array_reader.next_batch(50).unwrap();
Expand Down Expand Up @@ -361,7 +362,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

// Read first 50 values, which are all from the first column chunk
let array = array_reader.next_batch(50).unwrap();
Expand Down Expand Up @@ -561,7 +562,7 @@ mod tests {
column_desc,
None,
)
.unwrap();
.unwrap();

let mut accu_len: usize = 0;

Expand Down Expand Up @@ -602,4 +603,110 @@ mod tests {
);
}
}


/// Reads parquet `INT32` and `INT64` columns annotated as `DECIMAL` through
/// `PrimitiveArrayReader` and checks that each first batch of 50 values comes
/// back as a `Decimal128Array` carrying the precision and scale declared in
/// the schema.
#[test]
fn test_primitive_array_reader_decimal_types() {
    // ---- Case 1: parquet `INT32` annotated as DECIMAL(8,2) ----
    let message_type = "
message test_schema {
REQUIRED INT32 decimal1 (DECIMAL(8,2));
}
";
    let schema = parse_message_type(message_type)
        .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t))))
        .unwrap();
    let column_desc = schema.column(0);

    {
        // Generate PLAIN-encoded pages; the generated values are captured in
        // `values` so they can be compared against what the reader returns.
        let mut values = Vec::new();
        let mut pages = Vec::new();
        make_column_chunks::<Int32Type>(
            column_desc.clone(),
            Encoding::PLAIN,
            100,
            -99_999_999,
            99_999_999,
            &mut Vec::new(),
            &mut Vec::new(),
            &mut values,
            &mut pages,
            true,
            2,
        );

        let page_iterator =
            InMemoryPageIterator::new(schema, column_desc.clone(), pages);
        let mut array_reader = PrimitiveArrayReader::<Int32Type>::new(
            Box::new(page_iterator),
            column_desc,
            None,
        )
        .unwrap();

        // The batch must be reported as decimal(8,2).
        let batch = array_reader.next_batch(50).unwrap();
        let expected_type = Decimal128(8, 2);
        assert_eq!(batch.data_type(), &expected_type);

        let decimals = batch.as_any().downcast_ref::<Decimal128Array>().unwrap();

        // Same values with the same precision/scale compare equal.
        let same_type: Decimal128Array = values[0..50]
            .iter()
            .copied()
            .map(|v| Some(v as i128))
            .collect();
        let same_type = same_type.with_precision_and_scale(8, 2).unwrap();
        assert_eq!(decimals, &same_type);

        // Same values with a different precision/scale must not compare equal.
        let other_type: Decimal128Array = values[0..50]
            .iter()
            .copied()
            .map(|v| Some(v as i128))
            .collect();
        let other_type = other_type.with_precision_and_scale(9, 0).unwrap();
        assert_ne!(decimals, &other_type);
    }

    // ---- Case 2: parquet `INT64` annotated as DECIMAL(18,4) ----
    let message_type = "
message test_schema {
REQUIRED INT64 decimal1 (DECIMAL(18,4));
}
";
    let schema = parse_message_type(message_type)
        .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t))))
        .unwrap();
    let column_desc = schema.column(0);

    {
        // Generate PLAIN-encoded pages; the generated values are captured in
        // `values` so they can be compared against what the reader returns.
        let mut values = Vec::new();
        let mut pages = Vec::new();
        make_column_chunks::<Int64Type>(
            column_desc.clone(),
            Encoding::PLAIN,
            100,
            -999_999_999_999_999_999,
            999_999_999_999_999_999,
            &mut Vec::new(),
            &mut Vec::new(),
            &mut values,
            &mut pages,
            true,
            2,
        );

        let page_iterator =
            InMemoryPageIterator::new(schema, column_desc.clone(), pages);
        let mut array_reader = PrimitiveArrayReader::<Int64Type>::new(
            Box::new(page_iterator),
            column_desc,
            None,
        )
        .unwrap();

        // The batch must be reported as decimal(18,4).
        let batch = array_reader.next_batch(50).unwrap();
        let expected_type = Decimal128(18, 4);
        assert_eq!(batch.data_type(), &expected_type);

        let decimals = batch.as_any().downcast_ref::<Decimal128Array>().unwrap();

        // Same values with the same precision/scale compare equal.
        let same_type: Decimal128Array = values[0..50]
            .iter()
            .copied()
            .map(|v| Some(v as i128))
            .collect();
        let same_type = same_type.with_precision_and_scale(18, 4).unwrap();
        assert_eq!(decimals, &same_type);

        // Same values with a different precision/scale must not compare equal.
        let other_type: Decimal128Array = values[0..50]
            .iter()
            .copied()
            .map(|v| Some(v as i128))
            .collect();
        let other_type = other_type.with_precision_and_scale(34, 0).unwrap();
        assert_ne!(decimals, &other_type);
    }
}
}