Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Row AVG accumulator support Decimal type #5973

Merged
merged 3 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,31 @@ SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
----
NULL

# Creating the decimal table
statement ok
CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1))

# Inserting data
statement ok
INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL)

# aggregate_decimal_with_group_by
query IIRRRRIIR rowsort
select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1
----
1 2 15.15 30.3 10.1 20.2 2 0 NULL
2 2 15.15 30.3 10.1 20.2 2 0 NULL
3 2 10.1 20.2 10.1 10.1 1 0 NULL

# aggregate_decimal_with_group_by_decimal
query RIRRRRIR rowsort
select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3
----
100.1 2 10.1 20.2 10.1 10.1 0 NULL
200.2 1 20.2 20.2 20.2 20.2 0 NULL
700.1 2 15.15 30.3 10.1 20.2 0 NULL
NULL 1 10.1 10.1 10.1 10.1 0 NULL

# Restore the default dialect
statement ok
set datafusion.sql_parser.dialect = 'Generic';
59 changes: 45 additions & 14 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ impl AggregateExpr for Avg {
) -> Result<Box<dyn RowAccumulator>> {
Ok(Box::new(AvgRowAccumulator::new(
start_index,
self.sum_data_type.clone(),
&self.sum_data_type,
&self.rt_data_type,
)))
}

Expand Down Expand Up @@ -236,7 +237,7 @@ impl Accumulator for AvgAccumulator {
})
}
_ => Err(DataFusionError::Internal(
"Sum should be f64 on average".to_string(),
"Sum should be f64 or decimal128 on average".to_string(),
)),
}
}
Expand All @@ -250,13 +251,19 @@ impl Accumulator for AvgAccumulator {
struct AvgRowAccumulator {
state_index: usize,
sum_datatype: DataType,
return_data_type: DataType,
}

impl AvgRowAccumulator {
pub fn new(start_index: usize, sum_datatype: DataType) -> Self {
pub fn new(
start_index: usize,
sum_datatype: &DataType,
return_data_type: &DataType,
) -> Self {
Self {
state_index: start_index,
sum_datatype,
sum_datatype: sum_datatype.clone(),
return_data_type: return_data_type.clone(),
}
}
}
Expand Down Expand Up @@ -298,16 +305,40 @@ impl RowAccumulator for AvgRowAccumulator {
}

fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
assert_eq!(self.sum_datatype, DataType::Float64);
Ok(match accessor.get_u64_opt(self.state_index()) {
None => ScalarValue::Float64(None),
Some(0) => ScalarValue::Float64(None),
Some(n) => ScalarValue::Float64(
accessor
.get_f64_opt(self.state_index() + 1)
.map(|f| f / n as f64),
),
})
match self.sum_datatype {
DataType::Decimal128(p, s) => {
match accessor.get_u64_opt(self.state_index()) {
None => Ok(ScalarValue::Decimal128(None, p, s)),
Some(0) => Ok(ScalarValue::Decimal128(None, p, s)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we translate 0 --> null here? (it is also done for Float64 below)?

I see you are just following the existing pattern, but it seems like this could be incorrect?

Maybe we could add a test that calls AVG on (-1 and 1) to see if we get 0 or NULL

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do some test on this today.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @alamb, the value at the state_index is for the count rather than the sum. When the count is 0, for the average, it should be NULL.

Some(n) => {
// now the sum_type and return type is not the same, need to convert the sum type to return type
accessor.get_i128_opt(self.state_index() + 1).map_or_else(
|| Ok(ScalarValue::Decimal128(None, p, s)),
|f| {
calculate_result_decimal_for_avg(
f,
n as i128,
s,
&self.return_data_type,
)
},
)
}
}
}
DataType::Float64 => Ok(match accessor.get_u64_opt(self.state_index()) {
None => ScalarValue::Float64(None),
Some(0) => ScalarValue::Float64(None),
Some(n) => ScalarValue::Float64(
accessor
.get_f64_opt(self.state_index() + 1)
.map(|f| f / n as f64),
),
}),
_ => Err(DataFusionError::Internal(
"Sum should be f64 or decimal128 on average".to_string(),
)),
}
}

#[inline(always)]
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ macro_rules! min_max_v2 {
ScalarValue::Int8(rhs) => {
typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP)
}
ScalarValue::Decimal128(rhs, ..) => {
typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP)
}
e => {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/aggregate/row_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool {
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
)
}
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ pub(crate) fn add_to_row(
ScalarValue::Int64(rhs) => {
sum_row!(index, accessor, rhs, i64)
}
ScalarValue::Decimal128(rhs, _, _) => {
sum_row!(index, accessor, rhs, i128)
}
_ => {
let msg =
format!("Row sum updater is not expected to receive a scalar {s:?}");
Expand Down
15 changes: 15 additions & 0 deletions datafusion/row/src/accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ impl<'a> RowAccessor<'a> {
fn_get_idx!(i64, 8);
fn_get_idx!(f32, 4);
fn_get_idx!(f64, 8);
fn_get_idx!(i128, 16);

fn_get_idx_opt!(bool);
fn_get_idx_opt!(u8);
Expand All @@ -205,6 +206,7 @@ impl<'a> RowAccessor<'a> {
fn_get_idx_opt!(i64);
fn_get_idx_opt!(f32);
fn_get_idx_opt!(f64);
fn_get_idx_opt!(i128);

fn_get_idx_scalar!(bool, Boolean);
fn_get_idx_scalar!(u8, UInt8);
Expand All @@ -218,6 +220,14 @@ impl<'a> RowAccessor<'a> {
fn_get_idx_scalar!(f32, Float32);
fn_get_idx_scalar!(f64, Float64);

fn get_decimal128_scalar(&self, idx: usize, p: u8, s: i8) -> ScalarValue {
if self.is_valid_at(idx) {
ScalarValue::Decimal128(Some(self.get_i128(idx)), p, s)
} else {
ScalarValue::Decimal128(None, p, s)
}
}

pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue {
match dt {
DataType::Boolean => self.get_bool_scalar(index),
Expand All @@ -231,6 +241,7 @@ impl<'a> RowAccessor<'a> {
DataType::UInt64 => self.get_u64_scalar(index),
DataType::Float32 => self.get_f32_scalar(index),
DataType::Float64 => self.get_f64_scalar(index),
DataType::Decimal128(p, s) => self.get_decimal128_scalar(index, *p, *s),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -264,6 +275,7 @@ impl<'a> RowAccessor<'a> {
fn_set_idx!(i64, 8);
fn_set_idx!(f32, 4);
fn_set_idx!(f64, 8);
fn_set_idx!(i128, 16);

fn set_i8(&mut self, idx: usize, value: i8) {
self.assert_index_valid(idx);
Expand All @@ -285,6 +297,7 @@ impl<'a> RowAccessor<'a> {
fn_add_idx!(i64);
fn_add_idx!(f32);
fn_add_idx!(f64);
fn_add_idx!(i128);

fn_max_min_idx!(u8, max);
fn_max_min_idx!(u16, max);
Expand All @@ -296,6 +309,7 @@ impl<'a> RowAccessor<'a> {
fn_max_min_idx!(i64, max);
fn_max_min_idx!(f32, max);
fn_max_min_idx!(f64, max);
fn_max_min_idx!(i128, max);

fn_max_min_idx!(u8, min);
fn_max_min_idx!(u16, min);
Expand All @@ -307,4 +321,5 @@ impl<'a> RowAccessor<'a> {
fn_max_min_idx!(i64, min);
fn_max_min_idx!(f32, min);
fn_max_min_idx!(f64, min);
fn_max_min_idx!(i128, min);
}
13 changes: 8 additions & 5 deletions datafusion/row/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec<usize>, usiz
let mut offset = null_width;
for f in schema.fields() {
offsets.push(offset);
assert!(!matches!(f.data_type(), DataType::Decimal128(_, _)));
// All of the current support types can fit into one single 8-bytes word.
// When we decide to support Decimal type in the future, its width would be
// of two 8-bytes words and should adapt the width calculation below.
offset += 8;
assert!(!matches!(f.data_type(), DataType::Decimal256(_, _)));
// All of the current support types can fit into one single 8-bytes word except for Decimal128.
// For Decimal128, its width is of two 8-bytes words.
match f.data_type() {
DataType::Decimal128(_, _) => offset += 16,
_ => offset += 8,
}
}
(offsets, offset - null_width)
}
Expand Down Expand Up @@ -241,6 +243,7 @@ fn supported_type(dt: &DataType, row_type: RowType) -> bool {
| Float64
| Date32
| Date64
| Decimal128(_, _)
)
}
}
Expand Down
15 changes: 15 additions & 0 deletions datafusion/row/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ impl<'a> RowReader<'a> {
get_idx!(i64, self, idx, 8)
}

fn get_decimal128(&self, idx: usize) -> i128 {
get_idx!(i128, self, idx, 16)
}

fn get_utf8(&self, idx: usize) -> &str {
self.assert_index_valid(idx);
let offset_size = self.get_u64(idx);
Expand Down Expand Up @@ -260,6 +264,14 @@ impl<'a> RowReader<'a> {
}
}

fn get_decimal128_opt(&self, idx: usize) -> Option<i128> {
if self.is_valid_at(idx) {
Some(self.get_decimal128(idx))
} else {
None
}
}

fn get_utf8_opt(&self, idx: usize) -> Option<&str> {
if self.is_valid_at(idx) {
Some(self.get_utf8(idx))
Expand Down Expand Up @@ -328,6 +340,7 @@ fn_read_field!(f64, Float64Builder);
fn_read_field!(date32, Date32Builder);
fn_read_field!(date64, Date64Builder);
fn_read_field!(utf8, StringBuilder);
fn_read_field!(decimal128, Decimal128Builder);

pub(crate) fn read_field_binary(
to: &mut Box<dyn ArrayBuilder>,
Expand Down Expand Up @@ -374,6 +387,7 @@ fn read_field(
Date64 => read_field_date64(to, col_idx, row),
Utf8 => read_field_utf8(to, col_idx, row),
Binary => read_field_binary(to, col_idx, row),
Decimal128(_, _) => read_field_decimal128(to, col_idx, row),
_ => unimplemented!(),
}
}
Expand Down Expand Up @@ -401,6 +415,7 @@ fn read_field_null_free(
Date64 => read_field_date64_null_free(to, col_idx, row),
Utf8 => read_field_utf8_null_free(to, col_idx, row),
Binary => read_field_binary_null_free(to, col_idx, row),
Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row),
_ => unimplemented!(),
}
}
18 changes: 17 additions & 1 deletion datafusion/row/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw};
use datafusion_common::cast::{
as_binary_array, as_date32_array, as_date64_array, as_string_array,
as_binary_array, as_date32_array, as_date64_array, as_decimal128_array,
as_string_array,
};
use datafusion_common::Result;
use std::cmp::max;
Expand Down Expand Up @@ -225,6 +226,10 @@ impl RowWriter {
set_idx!(8, self, idx, value)
}

fn set_decimal128(&mut self, idx: usize, value: i128) {
set_idx!(16, self, idx, value)
}

fn set_offset_size(&mut self, idx: usize, size: u32) {
let offset_and_size: u64 = (self.varlena_offset as u64) << 32 | (size as u64);
self.set_u64(idx, offset_and_size);
Expand Down Expand Up @@ -375,6 +380,16 @@ pub(crate) fn write_field_binary(
to.set_binary(col_idx, s);
}

pub(crate) fn write_field_decimal128(
to: &mut RowWriter,
from: &Arc<dyn Array>,
col_idx: usize,
row_idx: usize,
) {
let from = as_decimal128_array(from).unwrap();
to.set_decimal128(col_idx, from.value(row_idx));
}

fn write_field(
col_idx: usize,
row_idx: usize,
Expand All @@ -399,6 +414,7 @@ fn write_field(
Date64 => write_field_date64(row, col, col_idx, row_idx),
Utf8 => write_field_utf8(row, col, col_idx, row_idx),
Binary => write_field_binary(row, col, col_idx, row_idx),
Decimal128(_, _) => write_field_decimal128(row, col, col_idx, row_idx),
_ => unimplemented!(),
}
}