diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 1b48c37406d3..fc6f9c28e105 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction { ArrowTypeof, /// overlay OverLay, + /// levenshtein + Levenshtein, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -464,6 +466,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::FromUnixtime => Volatility::Immutable, BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, + BuiltinScalarFunction::Levenshtein => Volatility::Immutable, // Stable builtin functions BuiltinScalarFunction::Now => Volatility::Stable, @@ -829,6 +832,10 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "overlay") } + BuiltinScalarFunction::Levenshtein => { + utf8_to_int_type(&input_expr_types[0], "levenshtein") + } + BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::Levenshtein => Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + self.volatility(), + ), BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin | BuiltinScalarFunction::Atan @@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Trim => &["trim"], BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], + BuiltinScalarFunction::Levenshtein => &["levenshtein"], // regex functions BuiltinScalarFunction::RegexpMatch => &["regexp_match"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index bcf1aa0ca7e5..75b762804427 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -909,6 +909,7 @@ scalar_expr!( ); 
scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); +scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); scalar_expr!( Struct, @@ -1195,6 +1196,7 @@ mod test { test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); + test_scalar_expr!(Levenshtein, levenshtein, string1, string2); } #[test] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 1e8500079f21..b46249d26dde 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -846,6 +846,19 @@ pub fn create_physical_fun( "Unsupported data type {other:?} for function overlay", ))), }), + BuiltinScalarFunction::Levenshtein => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::levenshtein::<i32>)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::levenshtein::<i64>)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function levenshtein", + ))), + }) + } }) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 7e954fdcfdc4..91d21f95e41f 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,11 +23,12 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait, - StringArray, + Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, + OffsetSizeTrait, StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; +use datafusion_common::utils::datafusion_strsim; use datafusion_common::{ cast::{ as_generic_string_array, as_int64_array, as_primitive_array, 
as_string_array, @@ -643,12 +644,59 @@ pub fn overlay(args: &[ArrayRef]) -> Result { } } +/// Returns the Levenshtein distance between each pair of strings; null if either is null. +/// Utf8 yields Int32 and LargeUtf8 yields Int64, e.g. LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "levenshtein function requires two arguments, got {}", + args.len() + ))); + } + let str1_array = as_generic_string_array::<T>(&args[0])?; + let str2_array = as_generic_string_array::<T>(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::<Int32Array>(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::<Int64Array>(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." 
+ ) + } + } +} + #[cfg(test)] mod tests { use crate::string_expressions; use arrow::{array::Int32Array, datatypes::Int32Type}; use arrow_array::Int64Array; + use datafusion_common::cast::as_int32_array; use super::*; @@ -707,4 +755,19 @@ mod tests { Ok(()) } + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::<i32>(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 66c34c7a12ec..a5c3d3b603df 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -639,6 +639,7 @@ enum ScalarFunction { OverLay = 121; Range = 122; ArrayPopFront = 123; + Levenshtein = 124; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 628adcc41189..3faacca18c60 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20938,6 +20938,7 @@ impl serde::Serialize for ScalarFunction { Self::OverLay => "OverLay", Self::Range => "Range", Self::ArrayPopFront => "ArrayPopFront", + Self::Levenshtein => "Levenshtein", }; serializer.serialize_str(variant) } @@ -21073,6 +21074,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "OverLay", "Range", "ArrayPopFront", + "Levenshtein", ]; struct GeneratedVisitor; @@ -21237,6 +21239,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "OverLay" => Ok(ScalarFunction::OverLay), "Range" => Ok(ScalarFunction::Range), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), + "Levenshtein" => 
Ok(ScalarFunction::Levenshtein), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 317b888447a0..2555a31f6fe2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2570,6 +2570,7 @@ pub enum ScalarFunction { OverLay = 121, Range = 122, ArrayPopFront = 123, + Levenshtein = 124, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2702,6 +2703,7 @@ impl ScalarFunction { ScalarFunction::OverLay => "OverLay", ScalarFunction::Range => "Range", ScalarFunction::ArrayPopFront => "ArrayPopFront", + ScalarFunction::Levenshtein => "Levenshtein", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2831,6 +2833,7 @@ impl ScalarFunction { "OverLay" => Some(Self::OverLay), "Range" => Some(Self::Range), "ArrayPopFront" => Some(Self::ArrayPopFront), + "Levenshtein" => Some(Self::Levenshtein), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 94c9f9806621..f14da70485ab 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -50,7 +50,7 @@ use datafusion_expr::{ date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, - ln, log, log10, log2, + levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, ScalarFunction::OverLay => 
Self::OverLay, + ScalarFunction::Levenshtein => Self::Levenshtein, } } } @@ -1630,6 +1631,10 @@ pub fn parse_expr( )) } } + ScalarFunction::Levenshtein => Ok(levenshtein( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)), ScalarFunction::ToTimestampMillis => { Ok(to_timestamp_millis(parse_expr(&args[0], registry)?)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 53be5f7bd498..192956e92c00 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, BuiltinScalarFunction::OverLay => Self::OverLay, + BuiltinScalarFunction::Levenshtein => Self::Levenshtein, }; Ok(scalar_function) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 8f4230438480..9c8bb2c5f844 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES (1, 'OldBrand Product 1', 19.99), (2, 'OldBrand Product 2', 29.99), (3, 'OldBrand Product 3', 39.99), -(4, 'OldBrand Product 4', 49.99) +(4, 'OldBrand Product 4', 49.99) query ITR SELECT * REPLACE (price*2 AS price) FROM products @@ -857,3 +857,23 @@ NULL NULL Thomxas NULL + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query ? +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query ? 
+SELECT levenshtein(NULL, NULL) +---- +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index baaea3926f7d..f9f45a1b0a97 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -636,6 +636,7 @@ nullif(expression1, expression2) - [upper](#upper) - [uuid](#uuid) - [overlay](#overlay) +- [levenshtein](#levenshtein) ### `ascii` @@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count]) - **pos**: the start position to replace of str. - **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. +### `levenshtein` + +Returns the Levenshtein distance between the two given strings. +For example, `levenshtein('kitten', 'sitting') = 3` + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + ## Binary String Functions - [decode](#decode)