Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:implement calcite style 'levenshtein' string function #8168

Merged
merged 16 commits into from
Nov 17, 2023
12 changes: 12 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ pub enum BuiltinScalarFunction {
ArrowTypeof,
/// overlay
OverLay,
/// levenshtein
Levenshtein,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -464,6 +466,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::FromUnixtime => Volatility::Immutable,
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -829,6 +832,10 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "overlay")
}

BuiltinScalarFunction::Levenshtein => {
utf8_to_int_type(&input_expr_types[0], "levenshtein")
}

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1293,6 +1300,10 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Levenshtein => Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
self.volatility(),
),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
Expand Down Expand Up @@ -1457,6 +1468,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Trim => &["trim"],
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],
BuiltinScalarFunction::Levenshtein => &["levenshtein"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ scalar_expr!(
);

scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");

scalar_expr!(
Struct,
Expand Down Expand Up @@ -1195,6 +1196,7 @@ mod test {
test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
}

#[test]
Expand Down
13 changes: 13 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,19 @@ pub fn create_physical_fun(
"Unsupported data type {other:?} for function overlay",
))),
}),
BuiltinScalarFunction::Levenshtein => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::levenshtein::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::levenshtein::<i64>)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function levenshtein",
))),
})
}
})
}

Expand Down
67 changes: 65 additions & 2 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@

use arrow::{
array::{
Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, OffsetSizeTrait,
StringArray,
Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array,
OffsetSizeTrait, StringArray,
},
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::{
cast::{
as_generic_string_array, as_int64_array, as_primitive_array, as_string_array,
Expand Down Expand Up @@ -643,12 +644,59 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

///Returns the Levenshtein distance between the two given strings.
/// LEVENSHTEIN('kitten', 'sitting') = 3
pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
alamb marked this conversation as resolved.
Show resolved Hide resolved
if args.len() != 2 {
return Err(DataFusionError::Internal(format!(
"levenshtein function requires two arguments, got {}",
args.len()
)));
}
let str1_array = as_generic_string_array::<T>(&args[0])?;
let str2_array = as_generic_string_array::<T>(&args[1])?;
match args[0].data_type() {
DataType::Utf8 => {
let result = str1_array
.iter()
.zip(str2_array.iter())
.map(|(string1, string2)| match (string1, string2) {
(Some(string1), Some(string2)) => {
Some(datafusion_strsim::levenshtein(string1, string2) as i32)
}
_ => None,
})
.collect::<Int32Array>();
Ok(Arc::new(result) as ArrayRef)
}
DataType::LargeUtf8 => {
let result = str1_array
.iter()
.zip(str2_array.iter())
.map(|(string1, string2)| match (string1, string2) {
(Some(string1), Some(string2)) => {
Some(datafusion_strsim::levenshtein(string1, string2) as i64)
}
_ => None,
})
.collect::<Int64Array>();
Ok(Arc::new(result) as ArrayRef)
}
other => {
internal_err!(
"levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
)
}
}
}

#[cfg(test)]
mod tests {

use crate::string_expressions;
use arrow::{array::Int32Array, datatypes::Int32Type};
use arrow_array::Int64Array;
use datafusion_common::cast::as_int32_array;

use super::*;

Expand Down Expand Up @@ -707,4 +755,19 @@ mod tests {

Ok(())
}

// Checks `levenshtein` on a pair of `Utf8` arrays against known distances.
#[test]
fn to_levenshtein() -> Result<()> {
    let string1_array =
        Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"]));
    let string2_array =
        Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"]));
    // The test already returns `Result`, so propagate failures with `?`
    // instead of panicking through `unwrap`/`expect`.
    let res = levenshtein::<i32>(&[string1_array, string2_array])?;
    let result = as_int32_array(&res)?;
    let expected = Int32Array::from(vec![2, 3, 2, 3]);
    assert_eq!(&expected, result);

    Ok(())
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ enum ScalarFunction {
OverLay = 121;
Range = 122;
ArrayPopFront = 123;
Levenshtein = 124;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use datafusion_expr::{
date_part, date_trunc, decode, degrees, digest, encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left,
ln, log, log10, log2,
levenshtein, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
Expand Down Expand Up @@ -549,6 +549,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Iszero => Self::Iszero,
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
ScalarFunction::Levenshtein => Self::Levenshtein,
}
}
}
Expand Down Expand Up @@ -1630,6 +1631,10 @@ pub fn parse_expr(
))
}
}
ScalarFunction::Levenshtein => Ok(levenshtein(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ToHex => Ok(to_hex(parse_expr(&args[0], registry)?)),
ScalarFunction::ToTimestampMillis => {
Ok(to_timestamp_millis(parse_expr(&args[0], registry)?))
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Iszero => Self::Iszero,
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
};

Ok(scalar_function)
Expand Down
22 changes: 21 additions & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ INSERT INTO products (product_id, product_name, price) VALUES
(1, 'OldBrand Product 1', 19.99),
(2, 'OldBrand Product 2', 29.99),
(3, 'OldBrand Product 3', 39.99),
(4, 'OldBrand Product 4', 49.99)
(4, 'OldBrand Product 4', 49.99)

query ITR
SELECT * REPLACE (price*2 AS price) FROM products
Expand Down Expand Up @@ -857,3 +857,23 @@ NULL
NULL
Thomxas
NULL

query I
SELECT levenshtein('kitten', 'sitting')
----
3

query I
SELECT levenshtein('kitten', NULL)
----
NULL

query ?
SELECT levenshtein(NULL, 'sitting')
----
NULL

query ?
SELECT levenshtein(NULL, NULL)
----
NULL
15 changes: 15 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ nullif(expression1, expression2)
- [upper](#upper)
- [uuid](#uuid)
- [overlay](#overlay)
- [levenshtein](#levenshtein)

### `ascii`

Expand Down Expand Up @@ -1137,6 +1138,20 @@ overlay(str PLACING substr FROM pos [FOR count])
- **pos**: the position in str at which the replacement starts.
- **count**: the number of characters of str to replace, starting at pos. If not specified, the length of substr is used instead.

### `levenshtein`

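Returns the Levenshtein distance between the two given strings, i.e. the minimum number of single-character insertions, deletions, or substitutions needed to turn one string into the other. Returns NULL if either argument is NULL.
For example, `levenshtein('kitten', 'sitting') = 3`.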

```
levenshtein(str1, str2)
```

#### Arguments

- **str1**: String expression to compute Levenshtein distance with str2.
- **str2**: String expression to compute Levenshtein distance with str1.

## Binary String Functions

- [decode](#decode)
Expand Down