Skip to content

Commit

Permalink
resolve review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mingmwang committed Apr 18, 2023
1 parent 5bfc141 commit 4a35fad
Showing 1 changed file with 6 additions and 221 deletions.
227 changes: 6 additions & 221 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,9 @@ use arrow::array::*;
use arrow::compute::{cast, filter};
use arrow::datatypes::{DataType, Schema, UInt32Type};
use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
use arrow_array::types::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt64Type, UInt8Type,
};
use arrow_schema::{IntervalUnit, TimeUnit};
use datafusion_common::cast::{
as_boolean_array, as_decimal128_array, as_fixed_size_binary_array,
as_fixed_size_list_array, as_list_array, as_struct_array,
};
use datafusion_common::scalar::get_dict_value;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::get_arrayref_at_indices;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;
use datafusion_row::layout::RowLayout;
Expand Down Expand Up @@ -866,23 +858,6 @@ fn slice_and_maybe_filter(
Ok(filtered_arrays)
}

/// Reads one element out of a dynamically typed array and wraps it in the
/// matching `ScalarValue` variant.
///
/// Arguments:
/// * `$array` - expression exposing `as_any()`
/// * `$index` - position of the element to read
/// * `$ARRAYTYPE` - concrete array type to downcast to
/// * `$SCALAR` - name of the `ScalarValue` variant to build
///
/// Expands to `Ok(ScalarValue::$SCALAR(Some(..)))`, converting the element
/// with `Into`. The `unwrap` panics when the downcast yields `None`, i.e.
/// when `$array` is not actually a `$ARRAYTYPE`.
macro_rules! typed_cast_to_scalar {
    ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .unwrap();
        let raw = typed.value($index);
        Ok(ScalarValue::$SCALAR(Some(raw.into())))
    }};
}

/// Timezone-carrying variant of the typed scalar cast: reads one element out
/// of a dynamically typed array and wraps it, together with a cloned
/// timezone value, in the matching `ScalarValue` variant.
///
/// Arguments:
/// * `$array` - expression exposing `as_any()`
/// * `$index` - position of the element to read
/// * `$ARRAYTYPE` - concrete array type to downcast to
/// * `$SCALAR` - name of the two-field `ScalarValue` variant to build
/// * `$TZ` - expression whose `.clone()` supplies the second variant field
///
/// Expands to `Ok(ScalarValue::$SCALAR(Some(..), $TZ.clone()))`. The
/// `unwrap` panics when the downcast yields `None`, i.e. when `$array` is
/// not actually a `$ARRAYTYPE`.
macro_rules! typed_cast_tz_to_scalar {
    ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{
        let typed = $array
            .as_any()
            .downcast_ref::<$ARRAYTYPE>()
            .unwrap();
        let raw = typed.value($index);
        // The element is read before the timezone is cloned, matching the
        // argument evaluation order of a direct constructor call.
        let zone = $TZ.clone();
        Ok(ScalarValue::$SCALAR(Some(raw.into()), zone))
    }};
}

/// This function is similar to [ScalarValue::try_from_array] except for the null handling:
/// it returns [ScalarValue::Null] instead of a typed null such as `ScalarValue::Int32(None)`.
fn col_to_scalar(
Expand All @@ -898,199 +873,9 @@ fn col_to_scalar(
return Ok(ScalarValue::Null);
}
}
match array.data_type() {
DataType::Null => Ok(ScalarValue::Null),
DataType::Boolean => {
typed_cast_to_scalar!(array, row_index, BooleanArray, Boolean)
}
DataType::Int8 => typed_cast_to_scalar!(array, row_index, Int8Array, Int8),
DataType::Int16 => typed_cast_to_scalar!(array, row_index, Int16Array, Int16),
DataType::Int32 => typed_cast_to_scalar!(array, row_index, Int32Array, Int32),
DataType::Int64 => typed_cast_to_scalar!(array, row_index, Int64Array, Int64),
DataType::UInt8 => typed_cast_to_scalar!(array, row_index, UInt8Array, UInt8),
DataType::UInt16 => {
typed_cast_to_scalar!(array, row_index, UInt16Array, UInt16)
}
DataType::UInt32 => {
typed_cast_to_scalar!(array, row_index, UInt32Array, UInt32)
}
DataType::UInt64 => {
typed_cast_to_scalar!(array, row_index, UInt64Array, UInt64)
}
DataType::Float32 => {
typed_cast_to_scalar!(array, row_index, Float32Array, Float32)
}
DataType::Float64 => {
typed_cast_to_scalar!(array, row_index, Float64Array, Float64)
}
DataType::Decimal128(p, s) => {
let array = as_decimal128_array(array)?;
Ok(ScalarValue::Decimal128(
Some(array.value(row_index)),
*p,
*s,
))
}
DataType::Binary => {
typed_cast_to_scalar!(array, row_index, BinaryArray, Binary)
}
DataType::LargeBinary => {
typed_cast_to_scalar!(array, row_index, LargeBinaryArray, LargeBinary)
}
DataType::Utf8 => typed_cast_to_scalar!(array, row_index, StringArray, Utf8),
DataType::LargeUtf8 => {
typed_cast_to_scalar!(array, row_index, LargeStringArray, LargeUtf8)
}
DataType::List(nested_type) => {
let list_array = as_list_array(array)?;

let nested_array = list_array.value(row_index);
let scalar_vec = (0..nested_array.len())
.map(|i| ScalarValue::try_from_array(&nested_array, i))
.collect::<Result<Vec<_>>>()?;
let value = Some(scalar_vec);
Ok(ScalarValue::new_list(
value,
nested_type.data_type().clone(),
))
}
DataType::Date32 => {
typed_cast_to_scalar!(array, row_index, Date32Array, Date32)
}
DataType::Date64 => {
typed_cast_to_scalar!(array, row_index, Date64Array, Date64)
}
DataType::Time32(TimeUnit::Second) => {
typed_cast_to_scalar!(array, row_index, Time32SecondArray, Time32Second)
}
DataType::Time32(TimeUnit::Millisecond) => typed_cast_to_scalar!(
array,
row_index,
Time32MillisecondArray,
Time32Millisecond
),
DataType::Time64(TimeUnit::Microsecond) => typed_cast_to_scalar!(
array,
row_index,
Time64MicrosecondArray,
Time64Microsecond
),
DataType::Time64(TimeUnit::Nanosecond) => typed_cast_to_scalar!(
array,
row_index,
Time64NanosecondArray,
Time64Nanosecond
),
DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz_to_scalar!(
array,
row_index,
TimestampSecondArray,
TimestampSecond,
tz_opt
),
DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
typed_cast_tz_to_scalar!(
array,
row_index,
TimestampMillisecondArray,
TimestampMillisecond,
tz_opt
)
}
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
typed_cast_tz_to_scalar!(
array,
row_index,
TimestampMicrosecondArray,
TimestampMicrosecond,
tz_opt
)
}
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
typed_cast_tz_to_scalar!(
array,
row_index,
TimestampNanosecondArray,
TimestampNanosecond,
tz_opt
)
}
DataType::Dictionary(key_type, _) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array, row_index),
DataType::Int16 => get_dict_value::<Int16Type>(array, row_index),
DataType::Int32 => get_dict_value::<Int32Type>(array, row_index),
DataType::Int64 => get_dict_value::<Int64Type>(array, row_index),
DataType::UInt8 => get_dict_value::<UInt8Type>(array, row_index),
DataType::UInt16 => get_dict_value::<UInt16Type>(array, row_index),
DataType::UInt32 => get_dict_value::<UInt32Type>(array, row_index),
DataType::UInt64 => get_dict_value::<UInt64Type>(array, row_index),
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
};
// look up the index in the values dictionary
match values_index {
Some(values_index) => {
let value = ScalarValue::try_from_array(values_array, values_index)?;
Ok(ScalarValue::Dictionary(key_type.clone(), Box::new(value)))
}
// else entry was null, so return null
None => Ok(ScalarValue::Null),
}
}
DataType::Struct(fields) => {
let array = as_struct_array(array)?;
let mut field_values: Vec<ScalarValue> = Vec::new();
for col_index in 0..array.num_columns() {
let col_array = array.column(col_index);
let col_scalar = ScalarValue::try_from_array(col_array, row_index)?;
field_values.push(col_scalar);
}
Ok(ScalarValue::Struct(Some(field_values), fields.clone()))
}
DataType::FixedSizeList(nested_type, _len) => {
let list_array = as_fixed_size_list_array(array)?;
match list_array.is_null(row_index) {
true => Ok(ScalarValue::Null),
false => {
let nested_array = list_array.value(row_index);
let scalar_vec = (0..nested_array.len())
.map(|i| ScalarValue::try_from_array(&nested_array, i))
.collect::<Result<Vec<_>>>()?;
Ok(ScalarValue::new_list(
Some(scalar_vec),
nested_type.data_type().clone(),
))
}
}
}
DataType::FixedSizeBinary(_) => {
let array = as_fixed_size_binary_array(array)?;
let size = match array.data_type() {
DataType::FixedSizeBinary(size) => *size,
_ => unreachable!(),
};
Ok(ScalarValue::FixedSizeBinary(
size,
Some(array.value(row_index).into()),
))
}
DataType::Interval(IntervalUnit::DayTime) => {
typed_cast_to_scalar!(array, row_index, IntervalDayTimeArray, IntervalDayTime)
}
DataType::Interval(IntervalUnit::YearMonth) => typed_cast_to_scalar!(
array,
row_index,
IntervalYearMonthArray,
IntervalYearMonth
),
DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast_to_scalar!(
array,
row_index,
IntervalMonthDayNanoArray,
IntervalMonthDayNano
),
other => Err(DataFusionError::NotImplemented(format!(
"GroupedHashAggregate: can't create a scalar from array of type \"{other:?}\""
))),
let mut res = ScalarValue::try_from_array(array, row_index)?;
if res.is_null() {

This comment has been minimized.

Copy link
@Dandandan

Dandandan Apr 18, 2023

Contributor

Why not return the res directly here?

res = ScalarValue::Null;
}
Ok(res)
}

0 comments on commit 4a35fad

Please sign in to comment.