From 8f0c051f6a99c833cc572d250e5e28b57dbad947 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sat, 1 Apr 2023 12:35:52 +1100 Subject: [PATCH 1/5] Coerce case expression and when to common type --- .../tests/sqllogictests/test_files/select.slt | 6 ++ datafusion/optimizer/src/type_coercion.rs | 91 ++++++++++++++----- 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt index 101d031885a8..464de209f6cf 100644 --- a/datafusion/core/tests/sqllogictests/test_files/select.slt +++ b/datafusion/core/tests/sqllogictests/test_files/select.slt @@ -214,3 +214,9 @@ select * from (select 1 a union all select 2) b order by a limit null; query I select * from (select 1 a union all select 2) b order by a limit 0; ---- + +# select case when type coercion +query I +select CASE 10.5 WHEN 0 THEN 1 ELSE 10 END as col; +---- +10 diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 0be9c89b6ccb..252b433968c8 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -330,8 +330,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - // all the result of then and else should be convert to a common data type, - // if they can be coercible to a common data type, return error. + // all the result of then and else should be convert to a common data type + // if case.expr is provided, this should be converted to common data type for when + // if they can't be coercible to a common data type, then return error. + + // prepare types + let case_type = match &case.expr { + None => Ok(None), + Some(expr) => expr.get_type(&self.schema).map(Some), + }?; let then_types = case .when_then_expr .iter() @@ -341,29 +348,67 @@ impl TreeNodeRewriter for TypeCoercionRewriter { None => Ok(None), Some(expr) => expr.get_type(&self.schema).map(Some), }?; - let case_when_coerce_type = - get_coerce_type_for_case_when(&then_types, else_type.as_ref()); - match case_when_coerce_type { - None => Err(DataFusionError::Internal(format!( - "Failed to coerce then ({then_types:?}) and else ({else_type:?}) to common types in CASE WHEN expression" - ))), - Some(data_type) => { - let left = case.when_then_expr - .into_iter() - .map(|(when, then)| { - let then = then.cast_to(&data_type, &self.schema)?; - Ok((when, Box::new(then))) - }) + + // find common coercible types + let case_when_coerce_type = match case_type { + None => None, + Some(case_type) => { + let when_types = case + .when_then_expr + .iter() + .map(|when_then| when_then.0.get_type(&self.schema)) .collect::>>()?; - let right = match &case.else_expr { - None => None, - Some(expr) => { - Some(Box::new(expr.clone().cast_to(&data_type, &self.schema)?)) - } - }; - Ok(Expr::Case(Case::new(case.expr,left,right))) + let coerced_type = + get_coerce_type_for_case_when(&when_types, Some(&case_type)); + let coerced_type = coerced_type.ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ + to common types in CASE WHEN expression" + )) + })?; + Some(coerced_type) } - } + }; + let then_else_coerce_type = + get_coerce_type_for_case_when(&then_types, else_type.as_ref()) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ + to common types in CASE WHEN expression" + )) + })?; + + // do cast if found common coercible types + let case_expr = match case.expr.zip(case_when_coerce_type.as_ref()) { + None => None, + Some((case_expr, coercible_type)) => { + let coerced_expr = + case_expr.cast_to(coercible_type, &self.schema)?; + Some(Box::new(coerced_expr)) + } + }; + let when_then = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = match &case_when_coerce_type { + Some(case_when_coerce_type) => Box::new( + when.cast_to(case_when_coerce_type, &self.schema)?, + ), + None => when, + }; + let then = + Box::new(then.cast_to(&then_else_coerce_type, &self.schema)?); + Ok((when, then)) + }) + .collect::>>()?; + let else_expr = match &case.else_expr { + None => None, + Some(expr) => Some(Box::new( + expr.clone().cast_to(&then_else_coerce_type, &self.schema)?, + )), + }; + Ok(Expr::Case(Case::new(case_expr, when_then, else_expr))) } Expr::ScalarUDF { fun, args } => { let new_expr = coerce_arguments_for_signature( From 0bd24a15f8c5b9fd7da23d5b01ac2c9910a120a9 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sat, 1 Apr 2023 19:28:07 +1100 Subject: [PATCH 2/5] Update comments --- datafusion/optimizer/src/type_coercion.rs | 26 ++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 252b433968c8..4b48dc13066d 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -330,9 +330,29 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - // all the result of then and else should be convert to a common data type - // if case.expr is provided, this should be converted to common data type for when - // if they can't be coercible to a common data type, then return error. + // Given expressions like: + // + // CASE a1 + // WHEN a2 THEN b1 + // WHEN a3 THEN b2 + // ELSE b3 + // END + // + // or: + // + // CASE + // WHEN x1 THEN b1 + // WHEN x2 THEN b2 + // ELSE b3 + // END + // + // Then all aN (a1, a2, a3) must be converted to a common data type in the first example + // (case-when expression coercion) + // + // And all bN (b1, b2, b3) must be converted to a common data type in both examples + // (then-else expression coercion) + // + // If either fail to find a common data type, will return error // prepare types let case_type = match &case.expr { From 964c1468558cab20c86efeec88c8e0cbe144aed7 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sat, 1 Apr 2023 23:13:13 +1100 Subject: [PATCH 3/5] Refactoring --- datafusion/optimizer/src/type_coercion.rs | 65 +++++++++++------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 4b48dc13066d..b907ed140ae7 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -355,40 +355,40 @@ impl TreeNodeRewriter for TypeCoercionRewriter { // If either fail to find a common data type, will return error // prepare types - let case_type = match &case.expr { - None => Ok(None), - Some(expr) => expr.get_type(&self.schema).map(Some), - }?; + let case_type = case + .expr + .as_ref() + .map(|expr| expr.get_type(&self.schema)) + .transpose()?; let then_types = case .when_then_expr .iter() - .map(|when_then| when_then.1.get_type(&self.schema)) + .map(|(_when, then)| then.get_type(&self.schema)) .collect::>>()?; - let else_type = match &case.else_expr { - None => Ok(None), - Some(expr) => expr.get_type(&self.schema).map(Some), - }?; + let else_type = case + .else_expr + .as_ref() + .map(|expr| expr.get_type(&self.schema)) + .transpose()?; // find common coercible types - let case_when_coerce_type = match case_type { - None => None, - Some(case_type) => { + let case_when_coerce_type = case_type.as_ref() + .map(|case_type| { let when_types = case .when_then_expr .iter() - .map(|when_then| when_then.0.get_type(&self.schema)) + .map(|(when, _then)| when.get_type(&self.schema)) .collect::>>()?; let coerced_type = - get_coerce_type_for_case_when(&when_types, Some(&case_type)); - let coerced_type = coerced_type.ok_or_else(|| { + get_coerce_type_for_case_when(&when_types, Some(case_type)); + coerced_type.ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ to common types in CASE WHEN expression" )) - })?; - Some(coerced_type) - } - }; + }) + }) + .transpose()?; let then_else_coerce_type = get_coerce_type_for_case_when(&then_types, else_type.as_ref()) .ok_or_else(|| { @@ -399,14 +399,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { })?; // do cast if found common coercible types - let case_expr = match case.expr.zip(case_when_coerce_type.as_ref()) { - None => None, - Some((case_expr, coercible_type)) => { - let coerced_expr = - case_expr.cast_to(coercible_type, &self.schema)?; - Some(Box::new(coerced_expr)) - } - }; + let case_expr = case + .expr + .zip(case_when_coerce_type.as_ref()) + .map(|(case_expr, coercible_type)| { + case_expr.cast_to(coercible_type, &self.schema) + }) + .transpose()? + .map(Box::new); let when_then = case .when_then_expr .into_iter() @@ -422,12 +422,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Ok((when, then)) }) .collect::>>()?; - let else_expr = match &case.else_expr { - None => None, - Some(expr) => Some(Box::new( - expr.clone().cast_to(&then_else_coerce_type, &self.schema)?, - )), - }; + let else_expr = case + .else_expr + .map(|expr| expr.cast_to(&then_else_coerce_type, &self.schema)) + .transpose()? + .map(Box::new); Ok(Expr::Case(Case::new(case_expr, when_then, else_expr))) } Expr::ScalarUDF { fun, args } => { From 756068d2ea40ea352132abfd358520826d874ec2 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sun, 2 Apr 2023 10:00:43 +1000 Subject: [PATCH 4/5] Cast case when exprs to boolean if no case expr --- .../tests/sqllogictests/test_files/select.slt | 12 +++++-- datafusion/optimizer/src/type_coercion.rs | 31 +++++++++++++------ .../physical-expr/src/expressions/case.rs | 8 +++-- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt index 464de209f6cf..3de957101fee 100644 --- a/datafusion/core/tests/sqllogictests/test_files/select.slt +++ b/datafusion/core/tests/sqllogictests/test_files/select.slt @@ -215,8 +215,14 @@ query I select * from (select 1 a union all select 2) b order by a limit 0; ---- -# select case when type coercion +# select case when type coercion with case expression query I -select CASE 10.5 WHEN 0 THEN 1 ELSE 10 END as col; +select CASE 10.5 WHEN 0 THEN 1 ELSE 2 END as col; ---- -10 +2 + +# select case when type coercion without case expression +query I +select CASE WHEN 0 THEN 1 ELSE 2 END as col; +---- +2 diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index b907ed140ae7..241ba3ec2820 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -349,10 +349,16 @@ impl TreeNodeRewriter for TypeCoercionRewriter { // Then all aN (a1, a2, a3) must be converted to a common data type in the first example // (case-when expression coercion) // + // All xN (x1, x2) must be converted to a boolean data type in the second example + // (when-boolean expression coercion) + // // And all bN (b1, b2, b3) must be converted to a common data type in both examples // (then-else expression coercion) // - // If either fail to find a common data type, will return error + // If any fail to find and cast to a common/specific data type, will return error + // + // Note that case-when and when-boolean expression coercions are mutually exclusive + // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur // prepare types let case_type = case @@ -411,15 +417,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter { .when_then_expr .into_iter() .map(|(when, then)| { - let when = match &case_when_coerce_type { - Some(case_when_coerce_type) => Box::new( - when.cast_to(case_when_coerce_type, &self.schema)?, - ), - None => when, - }; - let then = - Box::new(then.cast_to(&then_else_coerce_type, &self.schema)?); - Ok((when, then)) + let when_type = + case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); + let when = + when.cast_to(when_type, &self.schema).map_err(|e| { + DataFusionError::Context( + format!( + "WHEN expressions in CASE couldn't be \ + converted to common type ({when_type})" + ), + Box::new(e), + ) + })?; + let then = then.cast_to(&then_else_coerce_type, &self.schema)?; + Ok((Box::new(when), Box::new(then))) }) .collect::>>()?; let else_expr = case diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b4a7b1c59b47..2d97d57324fe 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -195,8 +195,12 @@ impl CaseExpr { _ => when_value, }; let when_value = when_value.into_array(batch.num_rows()); - let when_value = as_boolean_array(&when_value) - .expect("WHEN expression did not return a BooleanArray"); + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; let then_value = self.when_then_expr[i] .1 From 37d1d5fcfb4574a32a098aff52937e47d2475b01 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Mon, 3 Apr 2023 22:15:19 +1000 Subject: [PATCH 5/5] Adding unit tests --- .../tests/sqllogictests/test_files/select.slt | 8 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_schema.rs | 9 +- datafusion/expr/src/type_coercion/other.rs | 20 +- datafusion/optimizer/src/type_coercion.rs | 378 ++++++++++++------ 5 files changed, 286 insertions(+), 131 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt index 3de957101fee..65cf24b6448c 100644 --- a/datafusion/core/tests/sqllogictests/test_files/select.slt +++ b/datafusion/core/tests/sqllogictests/test_files/select.slt @@ -217,12 +217,16 @@ select * from (select 1 a union all select 2) b order by a limit 0; # select case when type coercion with case expression query I -select CASE 10.5 WHEN 0 THEN 1 ELSE 2 END as col; +select CASE 10.5 WHEN 0 THEN 1 ELSE 2 END; ---- 2 # select case when type coercion without case expression query I -select CASE WHEN 0 THEN 1 ELSE 2 END as col; +select CASE + WHEN 10 = 5 THEN 1 + WHEN 'true' THEN 2 + ELSE 3 +END; ---- 2 diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2806683ab87b..13973bd34996 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -277,7 +277,7 @@ impl Display for BinaryExpr { } /// CASE expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions pub expr: Option>, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bfe464b232cd..479b74ea0852 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; -use crate::type_coercion::other::get_coerce_type_for_case_when; +use crate::type_coercion::other::get_coerce_type_for_case_expression; use crate::{aggregate_function, function, window_function}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; @@ -81,13 +81,12 @@ impl ExprSchemable for Expr { None => Ok(None), Some(expr) => expr.get_type(schema).map(Some), }?; - get_coerce_type_for_case_when(&then_types, else_type.as_ref()).ok_or_else( - || { + get_coerce_type_for_case_expression(&then_types, else_type.as_ref()) + .ok_or_else(|| { DataFusionError::Internal(String::from( "Cannot infer type for CASE statement", )) - }, - ) + }) } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs index 6ff1300f64e2..c53054e82112 100644 --- a/datafusion/expr/src/type_coercion/other.rs +++ b/datafusion/expr/src/type_coercion/other.rs @@ -34,20 +34,20 @@ pub fn get_coerce_type_for_list( }) } -/// Find a common coerceable type for all `then_types` as well -/// and the `else_type`, if specified. -/// Returns the common data type for `then_types` and `else_type` -pub fn get_coerce_type_for_case_when( - then_types: &[DataType], - else_type: Option<&DataType>, +/// Find a common coerceable type for all `when_or_then_types` as well +/// and the `case_or_else_type`, if specified. +/// Returns the common data type for `when_or_then_types` and `case_or_else_type` +pub fn get_coerce_type_for_case_expression( + when_or_then_types: &[DataType], + case_or_else_type: Option<&DataType>, ) -> Option { - let else_type = match else_type { - None => then_types[0].clone(), + let case_or_else_type = match case_or_else_type { + None => when_or_then_types[0].clone(), Some(data_type) => data_type.clone(), }; - then_types + when_or_then_types .iter() - .fold(Some(else_type), |left, right_type| match left { + .fold(Some(case_or_else_type), |left, right_type| match left { // failed to find a valid coercion in a previous iteration None => None, // TODO: now just use the `equal` coercion rule for case when. If find the issue, and diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 241ba3ec2820..35730c0f427c 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -32,7 +32,7 @@ use datafusion_expr::type_coercion::binary::{ }; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ - get_coerce_type_for_case_when, get_coerce_type_for_list, + get_coerce_type_for_case_expression, get_coerce_type_for_list, }; use datafusion_expr::type_coercion::{ is_date, is_numeric, is_timestamp, is_utf8_or_large_utf8, @@ -330,115 +330,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - // Given expressions like: - // - // CASE a1 - // WHEN a2 THEN b1 - // WHEN a3 THEN b2 - // ELSE b3 - // END - // - // or: - // - // CASE - // WHEN x1 THEN b1 - // WHEN x2 THEN b2 - // ELSE b3 - // END - // - // Then all aN (a1, a2, a3) must be converted to a common data type in the first example - // (case-when expression coercion) - // - // All xN (x1, x2) must be converted to a boolean data type in the second example - // (when-boolean expression coercion) - // - // And all bN (b1, b2, b3) must be converted to a common data type in both examples - // (then-else expression coercion) - // - // If any fail to find and cast to a common/specific data type, will return error - // - // Note that case-when and when-boolean expression coercions are mutually exclusive - // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur - - // prepare types - let case_type = case - .expr - .as_ref() - .map(|expr| expr.get_type(&self.schema)) - .transpose()?; - let then_types = case - .when_then_expr - .iter() - .map(|(_when, then)| then.get_type(&self.schema)) - .collect::>>()?; - let else_type = case - .else_expr - .as_ref() - .map(|expr| expr.get_type(&self.schema)) - .transpose()?; - - // find common coercible types - let case_when_coerce_type = case_type.as_ref() - .map(|case_type| { - let when_types = case - .when_then_expr - .iter() - .map(|(when, _then)| when.get_type(&self.schema)) - .collect::>>()?; - let coerced_type = - get_coerce_type_for_case_when(&when_types, Some(case_type)); - coerced_type.ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ - to common types in CASE WHEN expression" - )) - }) - }) - .transpose()?; - let then_else_coerce_type = - get_coerce_type_for_case_when(&then_types, else_type.as_ref()) - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ - to common types in CASE WHEN expression" - )) - })?; - - // do cast if found common coercible types - let case_expr = case - .expr - .zip(case_when_coerce_type.as_ref()) - .map(|(case_expr, coercible_type)| { - case_expr.cast_to(coercible_type, &self.schema) - }) - .transpose()? - .map(Box::new); - let when_then = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - let when_type = - case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); - let when = - when.cast_to(when_type, &self.schema).map_err(|e| { - DataFusionError::Context( - format!( - "WHEN expressions in CASE couldn't be \ - converted to common type ({when_type})" - ), - Box::new(e), - ) - })?; - let then = then.cast_to(&then_else_coerce_type, &self.schema)?; - Ok((Box::new(when), Box::new(then))) - }) - .collect::>>()?; - let else_expr = case - .else_expr - .map(|expr| expr.cast_to(&then_else_coerce_type, &self.schema)) - .transpose()? - .map(Box::new); - Ok(Expr::Case(Case::new(case_expr, when_then, else_expr))) + let case = coerce_case_expression(case, &self.schema)?; + Ok(Expr::Case(case)) } Expr::ScalarUDF { fun, args } => { let new_expr = coerce_arguments_for_signature( @@ -713,19 +606,130 @@ fn coerce_agg_exprs_for_signature( .collect::>>() } +fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { + // Given expressions like: + // + // CASE a1 + // WHEN a2 THEN b1 + // WHEN a3 THEN b2 + // ELSE b3 + // END + // + // or: + // + // CASE + // WHEN x1 THEN b1 + // WHEN x2 THEN b2 + // ELSE b3 + // END + // + // Then all aN (a1, a2, a3) must be converted to a common data type in the first example + // (case-when expression coercion) + // + // All xN (x1, x2) must be converted to a boolean data type in the second example + // (when-boolean expression coercion) + // + // And all bN (b1, b2, b3) must be converted to a common data type in both examples + // (then-else expression coercion) + // + // If any fail to find and cast to a common/specific data type, will return error + // + // Note that case-when and when-boolean expression coercions are mutually exclusive + // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur + + // prepare types + let case_type = case + .expr + .as_ref() + .map(|expr| expr.get_type(&schema)) + .transpose()?; + let then_types = case + .when_then_expr + .iter() + .map(|(_when, then)| then.get_type(&schema)) + .collect::>>()?; + let else_type = case + .else_expr + .as_ref() + .map(|expr| expr.get_type(&schema)) + .transpose()?; + + // find common coercible types + let case_when_coerce_type = case_type + .as_ref() + .map(|case_type| { + let when_types = case + .when_then_expr + .iter() + .map(|(when, _then)| when.get_type(&schema)) + .collect::>>()?; + let coerced_type = + get_coerce_type_for_case_expression(&when_types, Some(case_type)); + coerced_type.ok_or_else(|| { + DataFusionError::Plan(format!( + "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ + to common types in CASE WHEN expression" + )) + }) + }) + .transpose()?; + let then_else_coerce_type = + get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( + || { + DataFusionError::Plan(format!( + "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ + to common types in CASE WHEN expression" + )) + }, + )?; + + // do cast if found common coercible types + let case_expr = case + .expr + .zip(case_when_coerce_type.as_ref()) + .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, &schema)) + .transpose()? + .map(Box::new); + let when_then = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); + let when = when.cast_to(when_type, &schema).map_err(|e| { + DataFusionError::Context( + format!( + "WHEN expressions in CASE couldn't be \ + converted to common type ({when_type})" + ), + Box::new(e), + ) + })?; + let then = then.cast_to(&then_else_coerce_type, &schema)?; + Ok((Box::new(when), Box::new(then))) + }) + .collect::>>()?; + let else_expr = case + .else_expr + .map(|expr| expr.cast_to(&then_else_coerce_type, &schema)) + .transpose()? + .map(Box::new); + + Ok(Case::new(case_expr, when_then, else_expr)) +} + #[cfg(test)] mod test { use std::sync::Arc; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; + use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, Like}; use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, - BuiltinScalarFunction, ColumnarValue, StateTypeFunction, + BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, StateTypeFunction, }; use datafusion_expr::{ lit, @@ -738,6 +742,8 @@ mod test { use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter}; use crate::{OptimizerContext, OptimizerRule}; + use super::coerce_case_expression; + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { let rule = TypeCoercion::new(); let config = OptimizerContext::default(); @@ -1248,4 +1254,150 @@ mod test { assert_optimized_plan_eq(&plan, expected)?; Ok(()) } + + fn cast_if_not_same_type( + expr: Box, + data_type: &DataType, + schema: &DFSchemaRef, + ) -> Box { + if &expr.get_type(schema).unwrap() != data_type { + Box::new(cast(*expr, data_type.clone())) + } else { + expr + } + } + + fn cast_helper( + case: Case, + case_when_type: DataType, + then_else_type: DataType, + schema: &DFSchemaRef, + ) -> Case { + let expr = case + .expr + .map(|e| cast_if_not_same_type(e, &case_when_type, schema)); + let when_then_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + ( + cast_if_not_same_type(when, &case_when_type, schema), + cast_if_not_same_type(then, &then_else_type, schema), + ) + }) + .collect::>(); + let else_expr = case + .else_expr + .map(|e| cast_if_not_same_type(e, &then_else_type, schema)); + + Case { + expr, + when_then_expr, + else_expr, + } + } + + #[test] + fn test_case_expression_coercion() -> Result<()> { + let schema = Arc::new(DFSchema::new_with_metadata( + vec![ + DFField::new_unqualified("boolean", DataType::Boolean, true), + DFField::new_unqualified("integer", DataType::Int32, true), + DFField::new_unqualified("float", DataType::Float32, true), + DFField::new_unqualified( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + DFField::new_unqualified("date", DataType::Date32, true), + DFField::new_unqualified( + "interval", + DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano), + true, + ), + DFField::new_unqualified("binary", DataType::Binary, true), + DFField::new_unqualified("string", DataType::Utf8, true), + DFField::new_unqualified("decimal", DataType::Decimal128(10, 10), true), + ], + std::collections::HashMap::new(), + )?); + + let case = Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("boolean")), Box::new(col("integer"))), + (Box::new(col("integer")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: None, + }; + let case_when_common_type = DataType::Boolean; + let then_else_common_type = DataType::Utf8; + let expected = cast_helper( + case.clone(), + case_when_common_type, + then_else_common_type, + &schema, + ); + let actual = coerce_case_expression(case, &schema)?; + assert_eq!(expected, actual); + + let case = Case { + expr: Some(Box::new(col("string"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("integer"))), + (Box::new(col("integer")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: Some(Box::new(col("string"))), + }; + let case_when_common_type = DataType::Utf8; + let then_else_common_type = DataType::Utf8; + let expected = cast_helper( + case.clone(), + case_when_common_type, + then_else_common_type, + &schema, + ); + let actual = coerce_case_expression(case, &schema)?; + assert_eq!(expected, actual); + + let case = Case { + expr: Some(Box::new(col("interval"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("integer"))), + (Box::new(col("binary")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: Some(Box::new(col("string"))), + }; + let err = coerce_case_expression(case, &schema).unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: \ + Failed to coerce case (Interval(MonthDayNano)) and \ + when ([Float32, Binary, Utf8]) to common types in \ + CASE WHEN expression" + ); + + let case = Case { + expr: Some(Box::new(col("string"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("date"))), + (Box::new(col("string")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("binary"))), + ], + else_expr: Some(Box::new(col("timestamp"))), + }; + let err = coerce_case_expression(case, &schema).unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: \ + Failed to coerce then ([Date32, Float32, Binary]) and \ + else (Some(Timestamp(Nanosecond, None))) to common types \ + in CASE WHEN expression" + ); + + Ok(()) + } }