From e9209e82d32ccd48125c629d8173dea3d710b09e Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Sun, 10 Mar 2024 20:25:10 +0800 Subject: [PATCH 1/3] Port tan function --- datafusion/expr/src/built_in_function.rs | 6 - datafusion/expr/src/expr_fn.rs | 2 - datafusion/functions/src/math/mod.rs | 5 +- datafusion/functions/src/math/tan.rs | 110 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 1 - .../physical-expr/src/math_expressions.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - 11 files changed, 118 insertions(+), 23 deletions(-) create mode 100644 datafusion/functions/src/math/tan.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 991963a1bcca..0d857d97f858 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -102,8 +102,6 @@ pub enum BuiltinScalarFunction { Sinh, /// sqrt Sqrt, - /// tan - Tan, /// tanh Tanh, /// trunc @@ -342,7 +340,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sqrt => Volatility::Immutable, BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, - BuiltinScalarFunction::Tan => Volatility::Immutable, BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, @@ -741,7 +738,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tan | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Cot => match input_expr_types[0] { @@ -1051,7 +1047,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sin | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tan | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Cot => { // math expressions expect 1 argument of type f64 or f32 @@ -1148,7 +1143,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sin => &["sin"], BuiltinScalarFunction::Sinh => &["sinh"], BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tan => &["tan"], BuiltinScalarFunction::Tanh => &["tanh"], BuiltinScalarFunction::Trunc => &["trunc"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1e934267e9b8..2edf96087af4 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -538,7 +538,6 @@ scalar_expr!(Sqrt, sqrt, num, "square root of a number"); scalar_expr!(Cbrt, cbrt, num, "cube root of a number"); scalar_expr!(Sin, sin, num, "sine"); scalar_expr!(Cos, cos, num, "cosine"); -scalar_expr!(Tan, tan, num, "tangent"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); @@ -1239,7 +1238,6 @@ mod test { test_unary_scalar_expr!(Cbrt, cbrt); test_unary_scalar_expr!(Sin, sin); test_unary_scalar_expr!(Cos, cos); - test_unary_scalar_expr!(Tan, tan); test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 3741cc2802bb..310a83ee490a 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -21,12 +21,14 @@ mod abs; mod acos; mod asin; mod nans; +mod tan; // create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); make_udf_function!(acos::AcosFunc, ACOS, acos); make_udf_function!(asin::AsinFunc, ASIN, asin); +make_udf_function!(tan::TanFunc, TAN, tan); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( @@ -45,5 +47,6 @@ export_functions!( asin, num, "returns the arc sine or inverse sine of a number" - ) + ), + (tan, num, "returns the tangent of a number") ); diff --git a/datafusion/functions/src/math/tan.rs b/datafusion/functions/src/math/tan.rs new file mode 100644 index 000000000000..4822f3925e76 --- /dev/null +++ b/datafusion/functions/src/math/tan.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `tan()`. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow_array::{ArrayRef, Float32Array, Float64Array}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::utils::generate_signature_error_msg; +use datafusion_expr::Volatility; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct TanFunc { + signature: Signature, +} + +impl TanFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TanFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "tan" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return Err(plan_datafusion_err!( + "{}", + generate_signature_error_msg( + self.name(), + self.signature().clone(), + arg_types, + ) + )); + } + + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float64 => Ok(DataType::Float64), + DataType::Float32 => Ok(DataType::Float32), + + // For other types (possible values null/int), use Float 64 + _ => Ok(DataType::Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float64Array, + Float64Array, + { f64::tan } + )), + DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float32Array, + Float32Array, + { f32::tan } + )), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 2e1d48eb7620..c6225220182b 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -282,7 +282,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Sinh => Arc::new(math_expressions::sinh), BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt), - BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a8c115ba3a82..a677e126cb38 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -159,7 +159,6 @@ math_unary_function!("sqrt", sqrt); math_unary_function!("cbrt", cbrt); math_unary_function!("sin", sin); math_unary_function!("cos", cos); -math_unary_function!("tan", tan); math_unary_function!("sinh", sinh); math_unary_function!("cosh", cosh); math_unary_function!("tanh", tanh); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c068b253ce4b..66605cf79119 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -567,7 +567,7 @@ enum ScalarFunction { Signum = 15; Sin = 16; Sqrt = 17; - Tan = 18; + // Tan = 18; Trunc = 19; Array = 20; // RegexpMatch = 21; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7a366c08ad24..0742b9658c5a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22113,7 +22113,6 @@ impl serde::Serialize for ScalarFunction { Self::Signum => "Signum", Self::Sin => "Sin", Self::Sqrt => "Sqrt", - Self::Tan => "Tan", Self::Trunc => "Trunc", Self::Array => "Array", Self::BitLength => "BitLength", @@ -22232,7 +22231,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Signum", "Sin", "Sqrt", - "Tan", "Trunc", "Array", "BitLength", @@ -22380,7 +22378,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Signum" => Ok(ScalarFunction::Signum), "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), - "Tan" => Ok(ScalarFunction::Tan), "Trunc" => Ok(ScalarFunction::Trunc), "Array" => Ok(ScalarFunction::Array), "BitLength" => Ok(ScalarFunction::BitLength), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 79decc125226..d930c0cac9e4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2639,7 +2639,7 @@ pub enum ScalarFunction { Signum = 15, Sin = 16, Sqrt = 17, - Tan = 18, + /// Tan = 18; Trunc = 19, Array = 20, /// RegexpMatch = 21; @@ -2783,7 +2783,6 @@ impl ScalarFunction { ScalarFunction::Signum => "Signum", ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", - ScalarFunction::Tan => "Tan", ScalarFunction::Trunc => "Trunc", ScalarFunction::Array => "Array", ScalarFunction::BitLength => "BitLength", @@ -2896,7 +2895,6 @@ impl ScalarFunction { "Signum" => Some(Self::Signum), "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), - "Tan" => Some(Self::Tan), "Trunc" => Some(Self::Trunc), "Array" => Some(Self::Array), "BitLength" => Some(Self::BitLength), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cc5491b3f204..37a3847ddbbe 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -62,8 +62,8 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, string_to_array, strpos, - struct_fun, substr, substr_index, substring, tan, tanh, to_hex, translate, trim, - trunc, upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + struct_fun, substr, substr_index, substring, tanh, to_hex, translate, trim, trunc, + upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -448,7 +448,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Cbrt => Self::Cbrt, ScalarFunction::Sin => Self::Sin, ScalarFunction::Cos => Self::Cos, - ScalarFunction::Tan => Self::Tan, ScalarFunction::Cot => Self::Cot, ScalarFunction::Atan => Self::Atan, ScalarFunction::Sinh => Self::Sinh, @@ -1519,7 +1518,6 @@ pub fn parse_expr( ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Tan => Ok(tan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 77b576dcd3cd..fd2084e35ea4 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1426,7 +1426,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cbrt => Self::Cbrt, BuiltinScalarFunction::Sin => Self::Sin, BuiltinScalarFunction::Cos => Self::Cos, - BuiltinScalarFunction::Tan => Self::Tan, BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, From b553e85eb2cd5e4915d14440d3aae05233e4f5ab Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Sun, 10 Mar 2024 20:48:43 +0800 Subject: [PATCH 2/3] Port tanh function --- datafusion/expr/src/built_in_function.rs | 7 -- datafusion/expr/src/expr_fn.rs | 2 - datafusion/functions/src/math/mod.rs | 5 +- datafusion/functions/src/math/tanh.rs | 110 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 1 - .../physical-expr/src/math_expressions.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - 11 files changed, 118 insertions(+), 24 deletions(-) create mode 100644 datafusion/functions/src/math/tanh.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 0d857d97f858..d7b1113d4bf1 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -102,8 +102,6 @@ pub enum BuiltinScalarFunction { Sinh, /// sqrt Sqrt, - /// tanh - Tanh, /// trunc Trunc, /// cot @@ -340,7 +338,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sqrt => Volatility::Immutable, BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, - BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, BuiltinScalarFunction::ArraySort => Volatility::Immutable, @@ -738,7 +735,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Cot => match input_expr_types[0] { Float32 => Ok(Float32), @@ -1047,7 +1043,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sin | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Cot => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we @@ -1097,7 +1092,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Sinh | BuiltinScalarFunction::Sqrt | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Pi ) { @@ -1143,7 +1137,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sin => &["sin"], BuiltinScalarFunction::Sinh => &["sinh"], BuiltinScalarFunction::Sqrt => &["sqrt"], - BuiltinScalarFunction::Tanh => &["tanh"], BuiltinScalarFunction::Trunc => &["trunc"], // conditional functions diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 2edf96087af4..10c65b3ccea0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -541,7 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); -scalar_expr!(Tanh, tanh, num, "hyperbolic tangent"); scalar_expr!(Atan, atan, num, "inverse tangent"); scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine"); scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine"); @@ -1241,7 +1240,6 @@ mod test { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); - test_unary_scalar_expr!(Tanh, tanh); test_unary_scalar_expr!(Atan, atan); test_unary_scalar_expr!(Asinh, asinh); test_unary_scalar_expr!(Acosh, acosh); diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 310a83ee490a..e7ede6043a59 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -22,6 +22,7 @@ mod acos; mod asin; mod nans; mod tan; +mod tanh; // create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); @@ -29,6 +30,7 @@ make_udf_function!(abs::AbsFunc, ABS, abs); make_udf_function!(acos::AcosFunc, ACOS, acos); make_udf_function!(asin::AsinFunc, ASIN, asin); make_udf_function!(tan::TanFunc, TAN, tan); +make_udf_function!(tanh::TanhFunc, TANH, tanh); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( @@ -48,5 +50,6 @@ export_functions!( num, "returns the arc sine or inverse sine of a number" ), - (tan, num, "returns the tangent of a number") + (tan, num, "returns the tangent of a number"), + (tanh, num, "returns the hyperbolic tangent of a number") ); diff --git a/datafusion/functions/src/math/tanh.rs b/datafusion/functions/src/math/tanh.rs new file mode 100644 index 000000000000..01629d68ca63 --- /dev/null +++ b/datafusion/functions/src/math/tanh.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `tanh()`. + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow_array::{ArrayRef, Float32Array, Float64Array}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::utils::generate_signature_error_msg; +use datafusion_expr::Volatility; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct TanhFunc { + signature: Signature, +} + +impl TanhFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TanhFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "tanh" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return Err(plan_datafusion_err!( + "{}", + generate_signature_error_msg( + self.name(), + self.signature().clone(), + arg_types, + ) + )); + } + + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float64 => Ok(DataType::Float64), + DataType::Float32 => Ok(DataType::Float32), + + // For other types (possible values null/int), use Float 64 + _ => Ok(DataType::Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float64Array, + Float64Array, + { f64::tanh } + )), + DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + self.name(), + Float32Array, + Float32Array, + { f32::tanh } + )), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c6225220182b..f5d17d986328 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -282,7 +282,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Sinh => Arc::new(math_expressions::sinh), BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt), - BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh), BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a677e126cb38..db8855cb5400 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -161,7 +161,6 @@ math_unary_function!("sin", sin); math_unary_function!("cos", cos); math_unary_function!("sinh", sinh); math_unary_function!("cosh", cosh); -math_unary_function!("tanh", tanh); math_unary_function!("asin", asin); math_unary_function!("acos", acos); math_unary_function!("atan", atan); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 66605cf79119..8ccb548de586 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -628,7 +628,7 @@ enum ScalarFunction { Atanh = 76; Sinh = 77; Cosh = 78; - Tanh = 79; + // Tanh = 79; Pi = 80; Degrees = 81; Radians = 82; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0742b9658c5a..8984f02569a3 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22163,7 +22163,6 @@ impl serde::Serialize for ScalarFunction { Self::Atanh => "Atanh", Self::Sinh => "Sinh", Self::Cosh => "Cosh", - Self::Tanh => "Tanh", Self::Pi => "Pi", Self::Degrees => "Degrees", Self::Radians => "Radians", @@ -22281,7 +22280,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Atanh", "Sinh", "Cosh", - "Tanh", "Pi", "Degrees", "Radians", @@ -22428,7 +22426,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Atanh" => Ok(ScalarFunction::Atanh), "Sinh" => Ok(ScalarFunction::Sinh), "Cosh" => Ok(ScalarFunction::Cosh), - "Tanh" => Ok(ScalarFunction::Tanh), "Pi" => Ok(ScalarFunction::Pi), "Degrees" => Ok(ScalarFunction::Degrees), "Radians" => Ok(ScalarFunction::Radians), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d930c0cac9e4..d06b71075881 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2700,7 +2700,7 @@ pub enum ScalarFunction { Atanh = 76, Sinh = 77, Cosh = 78, - Tanh = 79, + /// Tanh = 79; Pi = 80, Degrees = 81, Radians = 82, @@ -2833,7 +2833,6 @@ impl ScalarFunction { ScalarFunction::Atanh => "Atanh", ScalarFunction::Sinh => "Sinh", ScalarFunction::Cosh => "Cosh", - ScalarFunction::Tanh => "Tanh", ScalarFunction::Pi => "Pi", ScalarFunction::Degrees => "Degrees", ScalarFunction::Radians => "Radians", @@ -2945,7 +2944,6 @@ impl ScalarFunction { "Atanh" => Some(Self::Atanh), "Sinh" => Some(Self::Sinh), "Cosh" => Some(Self::Cosh), - "Tanh" => Some(Self::Tanh), "Pi" => Some(Self::Pi), "Degrees" => Some(Self::Degrees), "Radians" => Some(Self::Radians), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 37a3847ddbbe..5feef4c53f4b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -62,8 +62,8 @@ use datafusion_expr::{ lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, string_to_array, strpos, - struct_fun, substr, substr_index, substring, tanh, to_hex, translate, trim, trunc, - upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + struct_fun, substr, substr_index, substring, to_hex, translate, trim, trunc, upper, + uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -452,7 +452,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Atan => Self::Atan, ScalarFunction::Sinh => Self::Sinh, ScalarFunction::Cosh => Self::Cosh, - ScalarFunction::Tanh => Self::Tanh, ScalarFunction::Asinh => Self::Asinh, ScalarFunction::Acosh => Self::Acosh, ScalarFunction::Atanh => Self::Atanh, @@ -1521,7 +1520,6 @@ pub fn parse_expr( ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Atanh => { Ok(atanh(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index fd2084e35ea4..6036af24f861 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1429,7 +1429,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, - BuiltinScalarFunction::Tanh => Self::Tanh, BuiltinScalarFunction::Atan => Self::Atan, BuiltinScalarFunction::Asinh => Self::Asinh, BuiltinScalarFunction::Acosh => Self::Acosh, From 3f0f66785448599c3b4504f49b59b0ecfeb8d2c0 Mon Sep 17 00:00:00 2001 From: Chih Wang Date: Mon, 11 Mar 2024 22:52:17 +0800 Subject: [PATCH 3/3] Remove length checking of arg_types in return_type --- datafusion/functions/src/math/tan.rs | 14 +------------- datafusion/functions/src/math/tanh.rs | 14 +------------- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/datafusion/functions/src/math/tan.rs b/datafusion/functions/src/math/tan.rs index 4822f3925e76..ea3e002f8489 100644 --- a/datafusion/functions/src/math/tan.rs +++ b/datafusion/functions/src/math/tan.rs @@ -22,8 +22,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use arrow_array::{ArrayRef, Float32Array, Float64Array}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; -use datafusion_expr::utils::generate_signature_error_msg; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Volatility; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; @@ -58,17 +57,6 @@ impl ScalarUDFImpl for TanFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return Err(plan_datafusion_err!( - "{}", - generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } - let arg_type = &arg_types[0]; match arg_type { diff --git a/datafusion/functions/src/math/tanh.rs b/datafusion/functions/src/math/tanh.rs index 01629d68ca63..af34681919ab 100644 --- a/datafusion/functions/src/math/tanh.rs +++ b/datafusion/functions/src/math/tanh.rs @@ -22,8 +22,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use arrow_array::{ArrayRef, Float32Array, Float64Array}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; -use datafusion_expr::utils::generate_signature_error_msg; +use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Volatility; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature}; @@ -58,17 +57,6 @@ impl ScalarUDFImpl for TanhFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return Err(plan_datafusion_err!( - "{}", - generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } - let arg_type = &arg_types[0]; match arg_type {