From 4d3168f07bd02a7297aac6b1ef2434789b196548 Mon Sep 17 00:00:00 2001 From: zy-kkk Date: Sun, 14 Jan 2024 20:50:43 +0800 Subject: [PATCH] feat:implement sql style 'ends_with' and 'instr' string function --- .../tests/dataframe/dataframe_functions.rs | 40 ++++++++ datafusion/expr/src/built_in_function.rs | 36 +++++--- datafusion/expr/src/expr_fn.rs | 4 + datafusion/physical-expr/src/functions.rs | 92 +++++++++++++++++++ .../physical-expr/src/string_expressions.rs | 62 +++++++++++++ datafusion/proto/src/generated/pbjson.rs | 5 + datafusion/proto/src/generated/prost.rs | 6 ++ .../proto/src/logical_plan/from_proto.rs | 16 +++- datafusion/proto/src/logical_plan/to_proto.rs | 2 + .../sqllogictest/test_files/functions.slt | 35 +++++++ .../source/user-guide/sql/scalar_functions.md | 32 +++++++ 11 files changed, 316 insertions(+), 14 deletions(-) diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index fe56fc22ea8cc..2d4203464300d 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -267,6 +267,26 @@ async fn test_fn_initcap() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_instr() -> Result<()> { + let expr = instr(col("a"), lit("b")); + + let expected = [ + "+-------------------------+", + "| instr(test.a,Utf8(\"b\")) |", + "+-------------------------+", + "| 2 |", + "| 2 |", + "| 0 |", + "| 5 |", + "+-------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_left() -> Result<()> { @@ -634,6 +654,26 @@ async fn test_fn_starts_with() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_ends_with() -> Result<()> { + let expr = ends_with(col("a"), lit("DEF")); + + let expected = [ + "+-------------------------------+", + "| ends_with(test.a,Utf8(\"DEF\")) |", + "+-------------------------------+", + "| 
true |", + "| false |", + "| false |", + "| false |", + "+-------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_strpos() -> Result<()> { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6f64642f60d9b..882eb9af26d68 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -221,8 +221,12 @@ pub enum BuiltinScalarFunction { DateTrunc, /// date_bin DateBin, + /// ends_with + EndsWith, /// initcap InitCap, + /// InStr + InStr, /// left Left, /// lpad @@ -446,7 +450,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::DatePart => Volatility::Immutable, BuiltinScalarFunction::DateTrunc => Volatility::Immutable, BuiltinScalarFunction::DateBin => Volatility::Immutable, + BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, + BuiltinScalarFunction::InStr => Volatility::Immutable, BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Lower => Volatility::Immutable, @@ -708,6 +714,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } + BuiltinScalarFunction::InStr => { + utf8_to_int_type(&input_expr_types[0], "instr") + } BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), BuiltinScalarFunction::Lower => { utf8_to_str_type(&input_expr_types[0], "lower") @@ -795,6 +804,7 @@ impl BuiltinScalarFunction { true, )))), BuiltinScalarFunction::StartsWith => Ok(Boolean), + BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos") } @@ -1262,17 +1272,19 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Strpos | 
BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ) - } + + BuiltinScalarFunction::EndsWith + | BuiltinScalarFunction::InStr + | BuiltinScalarFunction::Strpos + | BuiltinScalarFunction::StartsWith => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ), BuiltinScalarFunction::Substr => Signature::one_of( vec![ @@ -1524,7 +1536,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::InStr => &["instr"], BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 834420e413b0d..3c5b814761b08 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -798,6 +798,7 @@ scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"); scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. 
encoding can be base64 or hex"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); +scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); scalar_expr!( @@ -830,6 +831,7 @@ scalar_expr!(SHA512, sha512, string, "SHA-512 hash"); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`"); scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); +scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); @@ -1371,6 +1373,7 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); + test_scalar_expr!(InStr, instr, string, substring); test_scalar_expr!(Left, left, string, count); test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); @@ -1409,6 +1412,7 @@ mod test { test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value); 
test_scalar_expr!(StartsWith, starts_with, string, characters); + test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); test_scalar_expr!(Substr, substring, string, position, count); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 66e22d2302de2..3371c7a3b25c0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -543,6 +543,15 @@ pub fn create_physical_fun( internal_err!("Unsupported data type {other:?} for function initcap") } }), + BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::instr::<i32>)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::instr::<i64>)(args) + } + other => internal_err!("Unsupported data type {other:?} for function instr"), + }), BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); @@ -765,6 +774,17 @@ pub fn create_physical_fun( internal_err!("Unsupported data type {other:?} for function starts_with") } }), + BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ends_with::<i32>)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ends_with::<i64>)(args) + } + other => { + internal_err!("Unsupported data type {other:?} for function ends_with") + } + }), BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -1379,6 +1399,46 @@ mod tests { Utf8, StringArray ); + test_function!( + InStr, + &[lit("abc"), lit("b")], + Ok(Some(2)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + 
&[lit("abc"), lit("c")], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("abc"), lit("d")], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("abc"), lit("")], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("Helloworld"), lit("world")], + Ok(Some(6)), + i32, + Int32, + Int32Array + ); #[cfg(feature = "unicode_expressions")] test_function!( Left, @@ -2497,6 +2557,38 @@ mod tests { Boolean, BooleanArray ); + test_function!( + EndsWith, + &[lit("alphabet"), lit("alph"),], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit("alphabet"), lit("bet"),], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit(ScalarValue::Utf8(None)), lit("alph"),], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], + Ok(None), + bool, + Boolean, + BooleanArray + ); #[cfg(feature = "unicode_expressions")] test_function!( Strpos, diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 7d9fecf614075..5ff4389776230 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -296,6 +296,50 @@ pub fn initcap(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Returns the position of the first occurrence of substring in string. +/// The position is counted from 1. If the substring is not found, returns 0. +/// For example, instr('Helloworld', 'world') = 6. 
+pub fn instr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + let string_array = as_generic_string_array::<T>(&args[0])?; + let substr_array = as_generic_string_array::<T>(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let result = string_array + .iter() + .zip(substr_array.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => { + string.find(substr).map_or(0, |index| (index + 1) as i32) + } + _ => 0, + }) + .collect::<Int32Array>(); + + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = string_array + .iter() + .zip(substr_array.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => { + string.find(substr).map_or(0, |index| (index + 1) as i64) + } + _ => 0, + }) + .collect::<Int64Array>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "instr was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + /// Converts the string to all lower case. /// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> { @@ -476,6 +520,24 @@ pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(Arc::new(result) as ArrayRef) } +/// Returns true if string ends with suffix. +/// ends_with('alphabet', 'abet') = 't' +pub fn ends_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + let string_array = as_generic_string_array::<T>(&args[0])?; + let suffix_array = as_generic_string_array::<T>(&args[1])?; + + let result = string_array + .iter() + .zip(suffix_array.iter()) + .map(|(string, suffix)| match (string, suffix) { + (Some(string), Some(suffix)) => Some(string.ends_with(suffix)), + _ => None, + }) + .collect::<BooleanArray>(); + + Ok(Arc::new(result) as ArrayRef) +} + /// Converts the number to its equivalent hexadecimal representation. 
/// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d5d86b2179faf..f3be86d7e89dd 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22234,7 +22234,9 @@ impl serde::Serialize for ScalarFunction { Self::ConcatWithSeparator => "ConcatWithSeparator", Self::DatePart => "DatePart", Self::DateTrunc => "DateTrunc", + Self::EndsWith => "EndsWith", Self::InitCap => "InitCap", + Self::InStr => "InStr", Self::Left => "Left", Self::Lpad => "Lpad", Self::Lower => "Lower", @@ -22377,6 +22379,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "DatePart", "DateTrunc", "InitCap", + "InStr", "Left", "Lpad", "Lower", @@ -22547,7 +22550,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "DatePart" => Ok(ScalarFunction::DatePart), "DateTrunc" => Ok(ScalarFunction::DateTrunc), + "EndsWith" => Ok(ScalarFunction::EndsWith), "InitCap" => Ok(ScalarFunction::InitCap), + "InStr" => Ok(ScalarFunction::InStr), "Left" => Ok(ScalarFunction::Left), "Lpad" => Ok(ScalarFunction::Lpad), "Lower" => Ok(ScalarFunction::Lower), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7e262e620fa7d..4650f0487db81 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2759,6 +2759,8 @@ pub enum ScalarFunction { ArraySort = 128, ArrayDistinct = 129, ArrayResize = 130, + EndsWith = 131, + InStr = 132, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -2797,7 +2799,9 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::DatePart => "DatePart", ScalarFunction::DateTrunc => "DateTrunc", + ScalarFunction::EndsWith => "EndsWith", ScalarFunction::InitCap => "InitCap", + ScalarFunction::InStr => "InStr", ScalarFunction::Left => "Left", ScalarFunction::Lpad => "Lpad", ScalarFunction::Lower => "Lower", @@ -2933,7 +2937,9 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "DatePart" => Some(Self::DatePart), "DateTrunc" => Some(Self::DateTrunc), + "EndsWith" => Some(Self::EndsWith), "InitCap" => Some(Self::InitCap), + "InStr" => Some(Self::InStr), "Left" => Some(Self::Left), "Lpad" => Some(Self::Lpad), "Lower" => Some(Self::Lower), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c11599412d94b..f4e5243ea48c9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -51,10 +51,10 @@ use datafusion_expr::{ array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, decode, degrees, digest, encode, exp, + date_trunc, decode, degrees, digest, encode, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, - lcm, left, levenshtein, ln, log, log10, log2, + factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, instr, isnan, + iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -532,7 +532,9 @@ impl 
From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, + ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, + ScalarFunction::InStr => Self::InStr, ScalarFunction::Left => Self::Left, ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, @@ -1591,6 +1593,10 @@ pub fn parse_expr( } ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], registry)?)), + ScalarFunction::InStr => Ok(instr( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Gcd => Ok(gcd( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1670,6 +1676,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::EndsWith => Ok(ends_with( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Strpos => Ok(strpos( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ec9b886c1f222..d97801012b24e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1520,7 +1520,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, + BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, + BuiltinScalarFunction::InStr => Self::InStr, BuiltinScalarFunction::Left => Self::Left, BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, diff --git 
a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 1903088b0748d..39d8d3933e846 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -647,6 +647,21 @@ SELECT initcap(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- Foo +query I +SELECT instr('foobarbar', 'bar') +---- +4 + +query I +SELECT instr('foobarbar', 'aa') +---- +0 + +query I +SELECT instr('foobarbar', '') +---- +1 + query T SELECT lower('FOObar') ---- @@ -727,6 +742,26 @@ SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ---- bar +query B +SELECT starts_with('foobar', 'foo') +---- +true + +query B +SELECT starts_with('foobar', 'bar') +---- +false + +query B +SELECT ends_with('foobar', 'bar') +---- +true + +query B +SELECT ends_with('foobar', 'foo') +---- +false + query T SELECT trim(' foo ') ---- diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9dd008f8fc44c..1aa7f6de3864c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -613,7 +613,9 @@ nullif(expression1, expression2) - [concat](#concat) - [concat_ws](#concat_ws) - [chr](#chr) +- [ends_with](#ends_with) - [initcap](#initcap) +- [instr](#instr) - [left](#left) - [length](#length) - [lower](#lower) @@ -756,6 +758,20 @@ chr(expression) **Related functions**: [ascii](#ascii) +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to test. + Can be a constant, column, or function, and any combination of string operators. +- **substr**: Substring to test for. + ### `initcap` Capitalizes the first character in each word in the input string. 
@@ -774,6 +790,22 @@ initcap(str) [lower](#lower), [upper](#upper) +### `instr` + +Returns the position of the first occurrence of substr in str (counting from 1). +If substr does not appear in str, returns 0. + +``` +instr(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. +- **substr**: Substring expression to search for. + Can be a constant, column, or function, and any combination of string operators. + ### `left` Returns a specified number of characters from the left side of a string.