Skip to content

Commit

Permalink
Port tanh function
Browse files Browse the repository at this point in the history
  • Loading branch information
ongchi committed Mar 10, 2024
1 parent e9209e8 commit b553e85
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 24 deletions.
7 changes: 0 additions & 7 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ pub enum BuiltinScalarFunction {
Sinh,
/// sqrt
Sqrt,
/// tanh
Tanh,
/// trunc
Trunc,
/// cot
Expand Down Expand Up @@ -340,7 +338,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sqrt => Volatility::Immutable,
BuiltinScalarFunction::Cbrt => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
BuiltinScalarFunction::ArraySort => Volatility::Immutable,
Expand Down Expand Up @@ -738,7 +735,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
Expand Down Expand Up @@ -1047,7 +1043,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sin
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Cot => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
Expand Down Expand Up @@ -1097,7 +1092,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sinh
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Pi
) {
Expand Down Expand Up @@ -1143,7 +1137,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sin => &["sin"],
BuiltinScalarFunction::Sinh => &["sinh"],
BuiltinScalarFunction::Sqrt => &["sqrt"],
BuiltinScalarFunction::Tanh => &["tanh"],
BuiltinScalarFunction::Trunc => &["trunc"],

// conditional functions
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine");
scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
scalar_expr!(Atan, atan, num, "inverse tangent");
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
Expand Down Expand Up @@ -1241,7 +1240,6 @@ mod test {
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Asinh, asinh);
test_unary_scalar_expr!(Acosh, acosh);
Expand Down
5 changes: 4 additions & 1 deletion datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ mod acos;
mod asin;
mod nans;
mod tan;
mod tanh;

// create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);
make_udf_function!(acos::AcosFunc, ACOS, acos);
make_udf_function!(asin::AsinFunc, ASIN, asin);
make_udf_function!(tan::TanFunc, TAN, tan);
make_udf_function!(tanh::TanhFunc, TANH, tanh);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
Expand All @@ -48,5 +50,6 @@ export_functions!(
num,
"returns the arc sine or inverse sine of a number"
),
(tan, num, "returns the tangent of a number")
(tan, num, "returns the tangent of a number"),
(tanh, num, "returns the hyperbolic tangent of a number")
);
110 changes: 110 additions & 0 deletions datafusion/functions/src/math/tanh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `tanh()`.
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{ArrayRef, Float32Array, Float64Array};
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::utils::generate_signature_error_msg;
use datafusion_expr::Volatility;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature};

#[derive(Debug)]
pub struct TanhFunc {
signature: Signature,
}

impl TanhFunc {
pub fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Float64, DataType::Float32],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for TanhFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"tanh"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}

let arg_type = &arg_types[0];

match arg_type {
DataType::Float64 => Ok(DataType::Float64),
DataType::Float32 => Ok(DataType::Float32),

// For other types (possible values null/int), use Float 64
_ => Ok(DataType::Float64),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float64Array,
Float64Array,
{ f64::tanh }
)),
DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float32Array,
Float32Array,
{ f32::tanh }
)),
other => {
return exec_err!(
"Unsupported data type {other:?} for function {}",
self.name()
)
}
};
Ok(ColumnarValue::Array(arr))
}
}
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Sinh => Arc::new(math_expressions::sinh),
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt),
BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh),
BuiltinScalarFunction::Trunc => {
Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args))
}
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ math_unary_function!("sin", sin);
math_unary_function!("cos", cos);
math_unary_function!("sinh", sinh);
math_unary_function!("cosh", cosh);
math_unary_function!("tanh", tanh);
math_unary_function!("asin", asin);
math_unary_function!("acos", acos);
math_unary_function!("atan", atan);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ enum ScalarFunction {
Atanh = 76;
Sinh = 77;
Cosh = 78;
Tanh = 79;
// Tanh = 79;
Pi = 80;
Degrees = 81;
Radians = 82;
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ use datafusion_expr::{
lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians,
random, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384,
sha512, signum, sin, sinh, split_part, sqrt, starts_with, string_to_array, strpos,
struct_fun, substr, substr_index, substring, tanh, to_hex, translate, trim, trunc,
upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction,
struct_fun, substr, substr_index, substring, to_hex, translate, trim, trunc, upper,
uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction,
BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField,
GroupingSet,
GroupingSet::GroupingSets,
Expand Down Expand Up @@ -452,7 +452,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Atan => Self::Atan,
ScalarFunction::Sinh => Self::Sinh,
ScalarFunction::Cosh => Self::Cosh,
ScalarFunction::Tanh => Self::Tanh,
ScalarFunction::Asinh => Self::Asinh,
ScalarFunction::Acosh => Self::Acosh,
ScalarFunction::Atanh => Self::Atanh,
Expand Down Expand Up @@ -1521,7 +1520,6 @@ pub fn parse_expr(
ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)),
ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)),
ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)),
ScalarFunction::Tanh => Ok(tanh(parse_expr(&args[0], registry, codec)?)),
ScalarFunction::Atanh => {
Ok(atanh(parse_expr(&args[0], registry, codec)?))
}
Expand Down
1 change: 0 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Cot => Self::Cot,
BuiltinScalarFunction::Sinh => Self::Sinh,
BuiltinScalarFunction::Cosh => Self::Cosh,
BuiltinScalarFunction::Tanh => Self::Tanh,
BuiltinScalarFunction::Atan => Self::Atan,
BuiltinScalarFunction::Asinh => Self::Asinh,
BuiltinScalarFunction::Acosh => Self::Acosh,
Expand Down

0 comments on commit b553e85

Please sign in to comment.