Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply type_union_resolution to array and values #12753

Merged
merged 4 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,10 +471,16 @@ fn type_union_resolution_coercion(
let new_value_type = type_union_resolution_coercion(value_type, other_type);
new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t)))
}
(DataType::List(lhs), DataType::List(rhs)) => {
let new_item_type =
type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
}
_ => {
// numeric coercion is the same as comparison coercion, both find the narrowest type
// that can accommodate both types
binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| numeric_string_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -507,22 +513,6 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| struct_coercion(lhs_type, rhs_type))
}

/// Find the common type for two column types that appear in a `VALUES` expression.
///
/// When both sides already have the same type, that type is returned as-is.
/// Otherwise the following helpers are tried in order and the first one that
/// produces a type wins: [`binary_numeric_coercion`],
/// [`temporal_coercion_nonstrict_timezone`], [`string_coercion`] and
/// [`binary_coercion`]. `None` is returned when none of them yields a type.
///
/// For example `VALUES (1, 2), (3.0, 4.0)` where the first row is `Int32` and
/// the second row is `Float64` will coerce to `Float64`.
pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    if lhs_type != rhs_type {
        // The types differ: walk the coercion rules, first match wins.
        return binary_numeric_coercion(lhs_type, rhs_type)
            .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
            .or_else(|| string_coercion(lhs_type, rhs_type))
            .or_else(|| binary_coercion(lhs_type, rhs_type));
    }

    // Identical types need no coercion at all.
    Some(lhs_type.clone())
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
Expand Down
5 changes: 3 additions & 2 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use crate::logical_plan::{
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
use crate::type_coercion::binary::values_coercion;
use crate::utils::{
can_hash, columnize_expr, compare_sort_expr, expr_to_columns,
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
Expand All @@ -53,6 +52,7 @@ use datafusion_common::{
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
TableReference, ToDFSchema, UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

use super::dml::InsertOp;
use super::plan::{ColumnUnnestList, ColumnUnnestType};
Expand Down Expand Up @@ -209,7 +209,8 @@ impl LogicalPlanBuilder {
}
if let Some(prev_type) = common_type {
// get common type of each column values.
let Some(new_type) = values_coercion(&data_type, &prev_type) else {
let data_types = vec![prev_type.clone(), data_type.clone()];
let Some(new_type) = type_union_resolution(&data_types) else {
return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}");
};
common_type = Some(new_type);
Expand Down
23 changes: 15 additions & 8 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ pub fn data_types(
try_coerce_types(valid_types, current_types, &signature.type_signature)
}

/// Returns `true` when `type_signature` is one of the signature kinds that
/// `try_coerce_types` treats as "well-supported".
///
/// The accepted kinds are `UserDefined`, `Numeric`, `Coercible` and `Any`.
/// A `OneOf` signature is accepted only if every one of its nested signatures
/// is accepted (checked recursively).
fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
    match type_signature {
        // A composite signature qualifies only when all of its alternatives do.
        TypeSignature::OneOf(alternatives) => {
            alternatives.iter().all(is_well_supported_signature)
        }
        TypeSignature::UserDefined
        | TypeSignature::Numeric(_)
        | TypeSignature::Coercible(_)
        | TypeSignature::Any(_) => true,
        _ => false,
    }
}

fn try_coerce_types(
valid_types: Vec<Vec<DataType>>,
current_types: &[DataType],
Expand All @@ -175,14 +189,7 @@ fn try_coerce_types(
let mut valid_types = valid_types;

// Well-supported signature that returns exact valid types.
if !valid_types.is_empty()
&& matches!(
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_)
)
{
if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
// exact valid types
assert_eq!(valid_types.len(), 1);
let valid_types = valid_types.swap_remove(0);
Expand Down
54 changes: 23 additions & 31 deletions datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`ScalarUDFImpl`] definitions for `make_array` function.

use std::vec;
use std::{any::Any, sync::Arc};

use arrow::array::{ArrayData, Capacities, MutableArrayData};
Expand All @@ -26,9 +27,8 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
use datafusion_common::internal_err;
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

Expand Down Expand Up @@ -82,19 +82,12 @@ impl ScalarUDFImpl for MakeArray {
match arg_types.len() {
0 => Ok(empty_array_type()),
_ => {
let mut expr_type = DataType::Null;
for arg_type in arg_types {
if !arg_type.equals_datatype(&DataType::Null) {
expr_type = arg_type.clone();
break;
}
}

if expr_type.is_null() {
expr_type = DataType::Int64;
}

Ok(List(Arc::new(Field::new("item", expr_type, true))))
// At this point, all the types in the array should have been coerced to the same type
Ok(List(Arc::new(Field::new(
"item",
arg_types[0].to_owned(),
true,
))))
}
}
}
Expand All @@ -112,22 +105,21 @@ impl ScalarUDFImpl for MakeArray {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let new_type = arg_types.iter().skip(1).try_fold(
arg_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
)?;
Ok(vec![new_type; arg_types.len()])
if let Some(new_type) = type_union_resolution(arg_types) {
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
Ok(vec![DataType::Int64; arg_types.len()])
} else {
Ok(vec![new_type; arg_types.len()])
}
} else {
plan_err!(
"Fail to find the valid type between {:?} for {}",
arg_types,
self.name()
)
}
}
}

Expand Down
25 changes: 0 additions & 25 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
self.schema,
&func,
)?;
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(func, new_expr),
)))
Expand Down Expand Up @@ -756,30 +755,6 @@ fn coerce_arguments_for_signature_with_aggregate_udf(
.collect()
}

/// Applies function-specific argument casts on top of the signature-based coercion.
///
/// Only the function named `make_array` is special-cased: each argument whose
/// type is `FixedSizeList` is cast to a `List` that reuses the same field.
/// Arguments of every other function are returned unchanged.
///
/// # Errors
///
/// Returns an error if the type of an argument cannot be resolved against
/// `schema`, or if building the cast expression fails.
fn coerce_arguments_for_fun(
    expressions: Vec<Expr>,
    schema: &DFSchema,
    fun: &Arc<ScalarUDF>,
) -> Result<Vec<Expr>> {
    // Nothing to do unless this is `make_array`.
    if fun.name() != "make_array" {
        return Ok(expressions);
    }

    // Cast Fixedsizelist to List for array functions
    expressions
        .into_iter()
        .map(|expr| {
            // Propagate a type-resolution failure to the caller instead of
            // panicking: this function already returns `Result`.
            match expr.get_type(schema)? {
                DataType::FixedSizeList(field, _) => {
                    expr.cast_to(&DataType::List(field), schema)
                }
                _ => Ok(expr),
            }
        })
        .collect()
}

fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
// Given expressions like:
//
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6595,7 +6595,7 @@ select make_array(1, 2.0, null, 3)
query ?
select make_array(1.0, '2', null)
----
[1.0, 2, ]
[1.0, 2.0, ]

### FixedSizeListArray

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/errors.slt
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ from aggregate_test_100
order by c9


statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type
create table foo as values (1), ('foo');
32 changes: 20 additions & 12 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,17 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']);
{[1, 2]: [a, b], [3, 4]: [b]}

query ?
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems mixed types are not allowed in DuckDB either

----
{POST: 41, HEAD: ab, PATCH: 30}
{POST: 41, HEAD: 53, PATCH: 30}

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);

# Map keys can not be NULL
query error
SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30);

query ?
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
----
{POST: 41, HEAD: ab, PATCH: 30}

query ?
SELECT MAKE_MAP()
----
Expand Down Expand Up @@ -517,9 +516,12 @@ query error
SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL];

query ?
SELECT MAP { 'a': 1, 2: 3 };
SELECT MAP { 'a': 1, 'b': 3 };
----
{a: 1, 2: 3}
{a: 1, b: 3}

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT MAP { 'a': 1, 2: 3 };

# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
# query ?
Expand Down Expand Up @@ -610,9 +612,12 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7)
# Tests for map_keys

query ?
SELECT map_keys(MAP { 'a': 1, 2: 3 });
SELECT map_keys(MAP { 'a': 1, 'b': 3 });
----
[a, 2]
[a, b]

query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT map_keys(MAP { 'a': 1, 2: 3 });

query ?
SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t;
Expand Down Expand Up @@ -657,8 +662,11 @@ SELECT map_keys(column1) from map_array_table_1;

# Tests for map_values

query ?
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
SELECT map_values(MAP { 'a': 1, 2: 3 });

query ?
SELECT map_values(MAP { 'a': 1, 'b': 3 });
----
[1, 3]

Expand Down
10 changes: 8 additions & 2 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -348,17 +348,23 @@ VALUES (1),()
statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1
VALUES (1),(1,2)

statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0
query I
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Looks like an improvement to me

VALUES (1),('2')
----
1
2

query R
VALUES (1),(2.0)
----
1
2

statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1
query II
VALUES (1,2), (1,'2')
----
1 2
1 2

query IT
VALUES (1,'a'),(NULL,'b'),(3,'c')
Expand Down