Skip to content

Commit

Permalink
feat:implement sql style 'ends_with' and 'instr' string function
Browse files Browse the repository at this point in the history
  • Loading branch information
zy-kkk committed Jan 14, 2024
1 parent be361fd commit 4d3168f
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 14 deletions.
40 changes: 40 additions & 0 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ async fn test_fn_initcap() -> Result<()> {
Ok(())
}

/// Checks the rendered output of `instr(a, 'b')`.
///
/// The expected table lists one position per row; a `0` row is a value in
/// which the substring "b" does not occur. The expression is evaluated by
/// `assert_fn_batches!` (presumably against the shared `test` table named in
/// the column header — confirm against the macro definition).
#[tokio::test]
async fn test_fn_instr() -> Result<()> {
    // Exact pretty-printed result the query must produce.
    let expected_table = [
        "+-------------------------+",
        "| instr(test.a,Utf8(\"b\")) |",
        "+-------------------------+",
        "| 2 |",
        "| 2 |",
        "| 0 |",
        "| 5 |",
        "+-------------------------+",
    ];

    // 1-based position of the literal "b" inside column `a`.
    let position_of_b = instr(col("a"), lit("b"));
    assert_fn_batches!(position_of_b, expected_table);

    Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_left() -> Result<()> {
Expand Down Expand Up @@ -634,6 +654,26 @@ async fn test_fn_starts_with() -> Result<()> {
Ok(())
}

/// Checks the rendered output of `ends_with(a, 'DEF')`.
///
/// Only the first row is expected to be `true`; the remaining rows are
/// `false`. The expression is evaluated by `assert_fn_batches!` (presumably
/// against the shared `test` table named in the column header — confirm
/// against the macro definition).
#[tokio::test]
async fn test_fn_ends_with() -> Result<()> {
    // Exact pretty-printed result the query must produce.
    let expected_table = [
        "+-------------------------------+",
        "| ends_with(test.a,Utf8(\"DEF\")) |",
        "+-------------------------------+",
        "| true |",
        "| false |",
        "| false |",
        "| false |",
        "+-------------------------------+",
    ];

    // Does column `a` end with the literal suffix "DEF"?
    let has_def_suffix = ends_with(col("a"), lit("DEF"));
    assert_fn_batches!(has_def_suffix, expected_table);

    Ok(())
}

#[tokio::test]
#[cfg(feature = "unicode_expressions")]
async fn test_fn_strpos() -> Result<()> {
Expand Down
36 changes: 25 additions & 11 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ pub enum BuiltinScalarFunction {
DateTrunc,
/// date_bin
DateBin,
/// ends_with
EndsWith,
/// initcap
InitCap,
/// InStr
InStr,
/// left
Left,
/// lpad
Expand Down Expand Up @@ -446,7 +450,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::DatePart => Volatility::Immutable,
BuiltinScalarFunction::DateTrunc => Volatility::Immutable,
BuiltinScalarFunction::DateBin => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,
BuiltinScalarFunction::InStr => Volatility::Immutable,
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
Expand Down Expand Up @@ -708,6 +714,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::InStr => {
utf8_to_int_type(&input_expr_types[0], "instr")
}
BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
BuiltinScalarFunction::Lower => {
utf8_to_str_type(&input_expr_types[0], "lower")
Expand Down Expand Up @@ -795,6 +804,7 @@ impl BuiltinScalarFunction {
true,
)))),
BuiltinScalarFunction::StartsWith => Ok(Boolean),
BuiltinScalarFunction::EndsWith => Ok(Boolean),
BuiltinScalarFunction::Strpos => {
utf8_to_int_type(&input_expr_types[0], "strpos")
}
Expand Down Expand Up @@ -1262,17 +1272,19 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => {
Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
)
}

BuiltinScalarFunction::EndsWith
| BuiltinScalarFunction::InStr
| BuiltinScalarFunction::Strpos
| BuiltinScalarFunction::StartsWith => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
),

BuiltinScalarFunction::Substr => Signature::one_of(
vec![
Expand Down Expand Up @@ -1524,7 +1536,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::Chr => &["chr"],
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::InStr => &["instr"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input
scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. encoding can be base64 or hex");
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
Expand Down Expand Up @@ -830,6 +831,7 @@ scalar_expr!(SHA512, sha512, string, "SHA-512 hash");
scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index.");
scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`");
scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`");
scalar_expr!(Substr, substr, string position, "substring from the `position` to the end");
scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters");
Expand Down Expand Up @@ -1371,6 +1373,7 @@ mod test {
test_scalar_expr!(Gcd, gcd, arg_1, arg_2);
test_scalar_expr!(Lcm, lcm, arg_1, arg_2);
test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(InStr, instr, string, substring);
test_scalar_expr!(Left, left, string, count);
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
Expand Down Expand Up @@ -1409,6 +1412,7 @@ mod test {
test_scalar_expr!(SplitPart, split_part, expr, delimiter, index);
test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value);
test_scalar_expr!(StartsWith, starts_with, string, characters);
test_scalar_expr!(EndsWith, ends_with, string, characters);
test_scalar_expr!(Strpos, strpos, string, substring);
test_scalar_expr!(Substr, substr, string, position);
test_scalar_expr!(Substr, substring, string, position, count);
Expand Down
92 changes: 92 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,15 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function initcap")
}
}),
BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::instr::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::instr::<i64>)(args)
}
other => internal_err!("Unsupported data type {other:?} for function instr"),
}),
BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left");
Expand Down Expand Up @@ -765,6 +774,17 @@ pub fn create_physical_fun(
internal_err!("Unsupported data type {other:?} for function starts_with")
}
}),
BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::ends_with::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::ends_with::<i64>)(args)
}
other => {
internal_err!("Unsupported data type {other:?} for function ends_with")
}
}),
BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_unicode_expressions_feature_flag!(
Expand Down Expand Up @@ -1379,6 +1399,46 @@ mod tests {
Utf8,
StringArray
);
test_function!(
InStr,
&[lit("abc"), lit("b")],
Ok(Some(2)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("c")],
Ok(Some(3)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("d")],
Ok(Some(0)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("abc"), lit("")],
Ok(Some(1)),
i32,
Int32,
Int32Array
);
test_function!(
InStr,
&[lit("Helloworld"), lit("world")],
Ok(Some(6)),
i32,
Int32,
Int32Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Left,
Expand Down Expand Up @@ -2497,6 +2557,38 @@ mod tests {
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("alph"),],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit("bet"),],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit(ScalarValue::Utf8(None)), lit("alph"),],
Ok(None),
bool,
Boolean,
BooleanArray
);
test_function!(
EndsWith,
&[lit("alphabet"), lit(ScalarValue::Utf8(None)),],
Ok(None),
bool,
Boolean,
BooleanArray
);
#[cfg(feature = "unicode_expressions")]
test_function!(
Strpos,
Expand Down
62 changes: 62 additions & 0 deletions datafusion/physical-expr/src/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,50 @@ pub fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Returns the position of the first occurrence of substring in string.
///
/// The position is counted from 1 and is measured in characters, not bytes.
/// If the substring is not found, returns 0. An empty substring is found at
/// position 1. If either input is NULL, the result for that row is NULL.
/// For example, instr('Helloworld', 'world') = 6.
///
/// `args[0]` is the string array and `args[1]` the substring array; both must
/// be string arrays with offset type `T`. The result is an `Int32Array` for
/// `Utf8` input and an `Int64Array` for `LargeUtf8` input.
///
/// # Errors
/// Returns an error if either argument cannot be downcast to a string array
/// with offset type `T`, or if `args[0]` is neither `Utf8` nor `LargeUtf8`.
pub fn instr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    let string_array = as_generic_string_array::<T>(&args[0])?;
    let substr_array = as_generic_string_array::<T>(&args[1])?;

    // `str::find` yields a *byte* offset. Convert it to a 1-based *character*
    // position by counting the characters that precede the match; the offset
    // returned by `find` always lies on a char boundary, so the slice is safe.
    let char_position = |string: &str, substr: &str| -> usize {
        string
            .find(substr)
            .map_or(0, |byte_idx| string[..byte_idx].chars().count() + 1)
    };

    match args[0].data_type() {
        DataType::Utf8 => {
            let result = string_array
                .iter()
                .zip(substr_array.iter())
                .map(|(string, substr)| {
                    // NULL in either input propagates as NULL (not 0, which
                    // would be indistinguishable from "not found").
                    let (string, substr) = string.zip(substr)?;
                    Some(char_position(string, substr) as i32)
                })
                .collect::<Int32Array>();

            Ok(Arc::new(result) as ArrayRef)
        }
        DataType::LargeUtf8 => {
            let result = string_array
                .iter()
                .zip(substr_array.iter())
                .map(|(string, substr)| {
                    let (string, substr) = string.zip(substr)?;
                    Some(char_position(string, substr) as i64)
                })
                .collect::<Int64Array>();

            Ok(Arc::new(result) as ArrayRef)
        }
        other => {
            internal_err!(
                "instr was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
            )
        }
    }
}

/// Converts the string to all lower case.
/// lower('TOM') = 'tom'
pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -476,6 +520,24 @@ pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

/// Tests, row by row, whether each string ends with the paired suffix.
/// ends_with('alphabet', 'abet') = true
///
/// `args[0]` supplies the strings and `args[1]` the suffixes; both must be
/// string arrays with offset type `T`. A row is NULL when either its string
/// or its suffix is NULL.
pub fn ends_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    let haystacks = as_generic_string_array::<T>(&args[0])?;
    let suffixes = as_generic_string_array::<T>(&args[1])?;

    let matches: BooleanArray = haystacks
        .iter()
        .zip(suffixes.iter())
        // `?` short-circuits to `None` (NULL) when either side is missing.
        .map(|(haystack, suffix)| Some(haystack?.ends_with(suffix?)))
        .collect();

    Ok(Arc::new(matches) as ArrayRef)
}

/// Converts the number to its equivalent hexadecimal representation.
/// to_hex(2147483647) = '7fffffff'
pub fn to_hex<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
Expand Down
5 changes: 5 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4d3168f

Please sign in to comment.