From cafff6108e53697b120f5d1c156dfa6d7cbc673b Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Thu, 30 Mar 2023 11:10:20 +0300 Subject: [PATCH] case expr type coercion --- datafusion/expr/src/expr_schema.rs | 50 ++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fafda79a6f61..a48bfd2da3cf 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,6 +21,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; +use crate::type_coercion::other::get_coerce_type_for_case_when; use crate::{aggregate_function, function, window_function}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; @@ -68,7 +69,25 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + let then_types = case + .when_then_expr + .iter() + .map(|when_then| when_then.1.get_type(schema)) + .collect::>>()?; + let else_type = case + .else_expr + .as_ref() + .map_or(Ok(None), |e| e.get_type(schema).map(Some))?; + + get_coerce_type_for_case_when(&then_types, else_type.as_ref()).ok_or_else( + || { + DataFusionError::Plan( + "Unable to coerce return type of CASE expression".to_owned(), + ) + }, + ) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::ScalarUDF { fun, args } => { @@ -289,9 +308,9 @@ impl ExprSchemable for Expr { #[cfg(test)] mod tests { use super::*; - use crate::{col, lit}; + use crate::{col, lit, when}; use arrow::datatypes::DataType; - use datafusion_common::Column; + use datafusion_common::{Column, ScalarValue}; #[test] fn expr_schema_nullability() { @@ -312,6 +331,31 @@ mod tests { ); } + #[test] + fn case_expr_data_type() { + let expr_else_null = when(col("foo").eq(lit(1)), lit(1_i32)) + .otherwise(lit(ScalarValue::Null)) + .unwrap(); + + assert_eq!( + DataType::Int32, + expr_else_null + .get_type(&MockExprSchema::new().with_data_type(DataType::Int64)) + .unwrap() + ); + + let expr_then_null = when(col("foo").eq(lit(1)), lit(ScalarValue::Null)) + .otherwise(lit(1_i32)) + .unwrap(); + + assert_eq!( + DataType::Int32, + expr_then_null + .get_type(&MockExprSchema::new().with_data_type(DataType::Int64)) + .unwrap() + ); + } + struct MockExprSchema { nullable: bool, data_type: DataType,