Skip to content

Commit

Permalink
Move nullif and isnan to datafusion-functions (#9216)
Browse files Browse the repository at this point in the history
* Move nullif and isnan to datafusion-functions

move nullif

* complete

* update tests

* fix
  • Loading branch information
alamb authored Feb 14, 2024
1 parent ac919d5 commit e1f7b24
Show file tree
Hide file tree
Showing 22 changed files with 353 additions and 105 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ jobs:
- name: Check function packages (encoding_expressions)
run: cargo check --no-default-features --features=encoding_expressions -p datafusion

- name: Check function packages (math_expressions)
run: cargo check --no-default-features --features=math_expressions -p datafusion

- name: Check function packages (array_expressions)
run: cargo check --no-default-features --features=array_expressions -p datafusion

Expand Down
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ default = ["array_expressions", "crypto_expressions", "encoding_expressions", "r
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = []
math_expressions = ["datafusion-functions/math_expressions"]
parquet = ["datafusion-common/parquet", "dep:parquet"]
pyarrow = ["datafusion-common/pyarrow", "parquet"]
regex_expressions = ["datafusion-physical-expr/regex_expressions", "datafusion-optimizer/regex_expressions"]
Expand Down
29 changes: 5 additions & 24 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use std::fmt;
use std::str::FromStr;
use std::sync::{Arc, OnceLock};

use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::signature::TIMEZONE_WILDCARD;
use crate::type_coercion::binary::get_wider_type;
use crate::type_coercion::functions::data_types;
Expand Down Expand Up @@ -83,8 +82,6 @@ pub enum BuiltinScalarFunction {
Gcd,
/// lcm, Least common multiple
Lcm,
/// isnan
Isnan,
/// iszero
Iszero,
/// ln, Natural logarithm
Expand Down Expand Up @@ -233,8 +230,6 @@ pub enum BuiltinScalarFunction {
Ltrim,
/// md5
MD5,
/// nullif
NullIf,
/// octet_length
OctetLength,
/// random
Expand Down Expand Up @@ -386,7 +381,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Floor => Volatility::Immutable,
BuiltinScalarFunction::Gcd => Volatility::Immutable,
BuiltinScalarFunction::Isnan => Volatility::Immutable,
BuiltinScalarFunction::Iszero => Volatility::Immutable,
BuiltinScalarFunction::Lcm => Volatility::Immutable,
BuiltinScalarFunction::Ln => Volatility::Immutable,
Expand Down Expand Up @@ -458,7 +452,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Lower => Volatility::Immutable,
BuiltinScalarFunction::Ltrim => Volatility::Immutable,
BuiltinScalarFunction::MD5 => Volatility::Immutable,
BuiltinScalarFunction::NullIf => Volatility::Immutable,
BuiltinScalarFunction::OctetLength => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::RegexpLike => Volatility::Immutable,
Expand Down Expand Up @@ -729,11 +722,6 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "ltrim")
}
BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"),
BuiltinScalarFunction::NullIf => {
// NULLIF has two args and they might get coerced, get a preview of this
let coerced_types = data_types(input_expr_types, &self.signature());
coerced_types.map(|typs| typs[0].clone())
}
BuiltinScalarFunction::OctetLength => {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
Expand Down Expand Up @@ -875,7 +863,7 @@ impl BuiltinScalarFunction {
_ => Ok(Float64),
},

BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => Ok(Boolean),
BuiltinScalarFunction::Iszero => Ok(Boolean),

BuiltinScalarFunction::ArrowTypeof => Ok(Utf8),

Expand Down Expand Up @@ -1300,9 +1288,6 @@ impl BuiltinScalarFunction {
self.volatility(),
),

BuiltinScalarFunction::NullIf => {
Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), self.volatility())
}
BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()),
Expand Down Expand Up @@ -1407,12 +1392,10 @@ impl BuiltinScalarFunction {
vec![Int32, Int64, UInt32, UInt64, Utf8],
self.volatility(),
),
BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => {
Signature::one_of(
vec![Exact(vec![Float32]), Exact(vec![Float64])],
self.volatility(),
)
}
BuiltinScalarFunction::Iszero => Signature::one_of(
vec![Exact(vec![Float32]), Exact(vec![Float64])],
self.volatility(),
),
}
}

Expand Down Expand Up @@ -1478,7 +1461,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Floor => &["floor"],
BuiltinScalarFunction::Gcd => &["gcd"],
BuiltinScalarFunction::Isnan => &["isnan"],
BuiltinScalarFunction::Iszero => &["iszero"],
BuiltinScalarFunction::Lcm => &["lcm"],
BuiltinScalarFunction::Ln => &["ln"],
Expand All @@ -1501,7 +1483,6 @@ impl BuiltinScalarFunction {

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
BuiltinScalarFunction::NullIf => &["nullif"],

// string functions
BuiltinScalarFunction::Ascii => &["ascii"],
Expand Down
8 changes: 0 additions & 8 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,6 @@ scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple");
scalar_expr!(Log2, log2, num, "base 2 logarithm");
scalar_expr!(Log10, log10, num, "base 10 logarithm");
scalar_expr!(Ln, ln, num, "natural logarithm");
scalar_expr!(NullIf, nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression.");
scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument");
scalar_expr!(
Expand Down Expand Up @@ -932,12 +931,6 @@ scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the sam
scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value");
scalar_expr!(MakeDate, make_date, year month day, "make a date from year, month and day component parts");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");
scalar_expr!(
Isnan,
isnan,
num,
"returns true if a given number is +NaN or -NaN otherwise returns false"
);
scalar_expr!(
Iszero,
iszero,
Expand Down Expand Up @@ -1369,7 +1362,6 @@ mod test {
test_unary_scalar_expr!(Ln, ln);
test_scalar_expr!(Atan2, atan2, y, x);
test_scalar_expr!(Nanvl, nanvl, x, y);
test_scalar_expr!(Isnan, isnan, input);
test_scalar_expr!(Iszero, iszero, input);

test_scalar_expr!(Ascii, ascii, input);
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ mod built_in_function;
mod built_in_window_function;
mod columnar_value;
mod literal;
mod nullif;
mod operator;
mod partition_evaluator;
mod signature;
Expand Down Expand Up @@ -74,7 +73,6 @@ pub use function::{
pub use groups_accumulator::{EmitTo, GroupsAccumulator};
pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
pub use logical_plan::*;
pub use nullif::SUPPORTED_NULLIF_TYPES;
pub use operator::Operator;
pub use partition_evaluator::PartitionEvaluator;
pub use signature::{
Expand Down
8 changes: 6 additions & 2 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ authors = { workspace = true }
rust-version = { workspace = true }

[features]
# enable core functions
core_expressions = []
# Enable encoding by default so the doctests work. In general don't automatically enable all packages.
default = ["encoding_expressions"]
# enable the encode/decode functions
default = ["core_expressions", "encoding_expressions", "math_expressions"]
# enable encode/decode functions
encoding_expressions = ["base64", "hex"]
# enable math functions
math_expressions = []


[lib]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
//! "core" DataFusion functions
mod nullif;

// create UDFs
make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression.")
);

/// Currently supported types by the nullif function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,89 @@
// specific language governing permissions and limitations
// under the License.

//! Encoding expressions
use arrow::{
datatypes::DataType,
};
use datafusion_common::{internal_err, Result, DataFusionError};
use datafusion_expr::{ColumnarValue};

use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use arrow::array::Array;
use arrow::compute::kernels::cmp::eq;
use arrow::compute::kernels::nullif::nullif;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use datafusion_common::{ ScalarValue};

#[derive(Debug)]
pub(super) struct NullIfFunc {
signature: Signature,
}

/// Currently supported types by the nullif function.
/// The order of these types correspond to the order on which coercion applies
/// This should thus be from least informative to most informative
static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
DataType::Boolean,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
DataType::Utf8,
DataType::LargeUtf8,
];


impl NullIfFunc {
pub fn new() -> Self {
Self {
signature:
Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(),
Volatility::Immutable,
)
}
}
}

impl ScalarUDFImpl for NullIfFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"nullif"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// NULLIF has two args and they might get coerced, get a preview of this
let coerced_types = datafusion_expr::type_coercion::functions::data_types(arg_types, &self.signature);
coerced_types.map(|typs| typs[0].clone())
.map_err(|e| e.context("Failed to coerce arguments for NULLIF")
)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
nullif_func(args)
}
}



/// Implements NULLIF(expr1, expr2)
/// Args: 0 - left expr is any array
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
///
pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return internal_err!(
"{:?} args were supplied but NULLIF takes exactly two args",
Expand Down
15 changes: 14 additions & 1 deletion datafusion/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,34 @@ use log::debug;
#[macro_use]
pub mod macros;

make_package!(core, "core_expressions", "Core datafusion expressions");

make_package!(
encoding,
"encoding_expressions",
"Hex and binary `encode` and `decode` functions."
);

make_package!(math, "math_expressions", "Mathematical functions.");

/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
#[cfg(feature = "core_expressions")]
pub use super::core::expr_fn::*;
#[cfg(feature = "encoding_expressions")]
pub use super::encoding::expr_fn::*;
#[cfg(feature = "math_expressions")]
pub use super::math::expr_fn::*;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
encoding::functions().into_iter().try_for_each(|udf| {
let mut all_functions = core::functions()
.into_iter()
.chain(encoding::functions())
.chain(math::functions());

all_functions.try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
if let Some(existing_udf) = existing_udf {
debug!("Overwrite existing UDF: {}", existing_udf.name());
Expand Down
39 changes: 39 additions & 0 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,42 @@ macro_rules! make_package {
}
};
}

/// Invokes a function on each element of an array and returns the result as a new array
///
/// $ARG: ArrayRef
/// $NAME: name of the function (for error messages)
/// $ARGS_TYPE: the type of array to cast the argument to
/// $RETURN_TYPE: the type of array to return
/// $FUNC: the function to apply to each element of $ARG
///
macro_rules! make_function_scalar_inputs_return_type {
($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$RETURN_TYPE>()
}};
}

/// Downcast an argument to a specific array type, returning an internal error
/// if the cast fails
///
/// $ARG: ArrayRef
/// $NAME: name of the argument (for error messages)
/// $ARRAY_TYPE: the type of array to cast the argument to
macro_rules! downcast_arg {
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
DataFusionError::Internal(format!(
"could not cast {} to {}",
$NAME,
std::any::type_name::<$ARRAY_TYPE>()
))
})?
}};
}
Loading

0 comments on commit e1f7b24

Please sign in to comment.