Skip to content

Commit

Permalink
feat:implement sql style 'substr_index' string function (apache#8272)
Browse files Browse the repository at this point in the history
* feat:implement sql style 'substr_index' string function

* code format

* code format

* code format

* fix index bound issue

* code format

* code format

* add args len check

* add sql tests

* code format

* doc format
  • Loading branch information
Syleechan authored Nov 26, 2023
1 parent f29bcf3 commit 234217e
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 3 deletions.
15 changes: 15 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction {
OverLay,
/// levenshtein
Levenshtein,
/// substr_index
SubstrIndex,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -470,6 +472,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
BuiltinScalarFunction::OverLay => Volatility::Immutable,
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
BuiltinScalarFunction::SubstrIndex => Volatility::Immutable,

// Stable builtin functions
BuiltinScalarFunction::Now => Volatility::Stable,
Expand Down Expand Up @@ -773,6 +776,9 @@ impl BuiltinScalarFunction {
return plan_err!("The to_hex function can only accept integers.");
}
}),
BuiltinScalarFunction::SubstrIndex => {
utf8_to_str_type(&input_expr_types[0], "substr_index")
}
BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] {
Int64 => Timestamp(Second, None),
_ => Timestamp(Nanosecond, None),
Expand Down Expand Up @@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction {
self.volatility(),
),

BuiltinScalarFunction::SubstrIndex => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
self.volatility(),
),

BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
}
Expand Down Expand Up @@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Upper => &["upper"],
BuiltinScalarFunction::Uuid => &["uuid"],
BuiltinScalarFunction::Levenshtein => &["levenshtein"],
BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"],

// regex functions
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,7 @@ scalar_expr!(

scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");
scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter");

scalar_expr!(
Struct,
Expand Down Expand Up @@ -1205,6 +1206,7 @@ mod test {
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count);
}

#[test]
Expand Down
23 changes: 23 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,29 @@ pub fn create_physical_fun(
))),
})
}
BuiltinScalarFunction::SubstrIndex => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
substr_index,
i32,
"substr_index"
);
make_scalar_function(func)(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
substr_index,
i64,
"substr_index"
);
make_scalar_function(func)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function substr_index",
))),
})
}
})
}

Expand Down
65 changes: 65 additions & 0 deletions datafusion/physical-expr/src/unicode_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,68 @@ pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

Ok(Arc::new(result) as ArrayRef)
}

/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return internal_err!(
"substr_index was called with {} arguments. It requires 3.",
args.len()
);
}

let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let count_array = as_int64_array(&args[2])?;

let result = string_array
.iter()
.zip(delimiter_array.iter())
.zip(count_array.iter())
.map(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
let mut res = String::new();
match n {
0 => {
"".to_string();
}
_other => {
if n > 0 {
let idx = string
.split(delimiter)
.take(n as usize)
.fold(0, |len, x| len + x.len() + delimiter.len())
- delimiter.len();
res.push_str(if idx >= string.len() {
string
} else {
&string[..idx]
});
} else {
let idx = (string.split(delimiter).take((-n) as usize).fold(
string.len() as isize,
|len, x| {
len - x.len() as isize - delimiter.len() as isize
},
) + delimiter.len() as isize)
as usize;
res.push_str(if idx >= string.len() {
string
} else {
&string[idx..]
});
}
}
}
Some(res)
}
_ => None,
})
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ enum ScalarFunction {
ArrayExcept = 123;
ArrayPopFront = 124;
Levenshtein = 125;
SubstrIndex = 126;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ use datafusion_expr::{
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index,
substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis,
to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid,
window_frame::regularize,
AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction,
Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
Expand Down Expand Up @@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
ScalarFunction::OverLay => Self::OverLay,
ScalarFunction::Levenshtein => Self::Levenshtein,
ScalarFunction::SubstrIndex => Self::SubstrIndex,
}
}
}
Expand Down Expand Up @@ -1716,6 +1717,11 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::SubstrIndex => Ok(substr_index(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::StructFun => {
Ok(struct_fun(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
BuiltinScalarFunction::OverLay => Self::OverLay,
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex,
};

Ok(scalar_function)
Expand Down
75 changes: 75 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,78 @@ query ?
SELECT levenshtein(NULL, NULL)
----
NULL

query T
SELECT substr_index('www.apache.org', '.', 1)
----
www

query T
SELECT substr_index('www.apache.org', '.', 2)
----
www.apache

query T
SELECT substr_index('www.apache.org', '.', -1)
----
org

query T
SELECT substr_index('www.apache.org', '.', -2)
----
apache.org

query T
SELECT substr_index('www.apache.org', 'ac', 1)
----
www.ap

query T
SELECT substr_index('www.apache.org', 'ac', -1)
----
he.org

query T
SELECT substr_index('www.apache.org', 'ac', 2)
----
www.apache.org

query T
SELECT substr_index('www.apache.org', 'ac', -2)
----
www.apache.org

query ?
SELECT substr_index(NULL, 'ac', 1)
----
NULL

query T
SELECT substr_index('www.apache.org', NULL, 1)
----
NULL

query T
SELECT substr_index('www.apache.org', 'ac', NULL)
----
NULL

query T
SELECT substr_index('', 'ac', 1)
----
(empty)

query T
SELECT substr_index('www.apache.org', '', 1)
----
(empty)

query T
SELECT substr_index('www.apache.org', 'ac', 0)
----
(empty)

query ?
SELECT substr_index(NULL, NULL, NULL)
----
NULL
18 changes: 18 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ nullif(expression1, expression2)
- [uuid](#uuid)
- [overlay](#overlay)
- [levenshtein](#levenshtein)
- [substr_index](#substr_index)

### `ascii`

Expand Down Expand Up @@ -1152,6 +1153,23 @@ levenshtein(str1, str2)
- **str1**: String expression to compute Levenshtein distance with str2.
- **str2**: String expression to compute Levenshtein distance with str1.

### `substr_index`

Returns the substring from str before count occurrences of the delimiter delim.
If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org`

```
substr_index(str, delim, count)
```

#### Arguments

- **str**: String expression to operate on.
- **delim**: the string to find in str to split str.
- **count**: The number of times to search for the delimiter. Can be both a positive or negative number.

## Binary String Functions

- [decode](#decode)
Expand Down

0 comments on commit 234217e

Please sign in to comment.