From e624204741b75d8858734456263e5249dc3cceeb Mon Sep 17 00:00:00 2001 From: Gosh Arzumanyan Date: Mon, 31 Oct 2022 13:53:31 -0700 Subject: [PATCH] Adding strpos/strrpos Presto functions based on Simple Functions API (#2903) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/2903 Based on the usage and benchmark analysis: 1. Move strpos Presto function implementation to use Simple Function API 2. Add strrpos Presto function based on Simple Function API 3. Add test coverage Reviewed By: mbasmanova Differential Revision: D40527007 fbshipit-source-id: 70ad282fa8d18683e1e36db004acf38c1207fee2 --- velox/docs/functions/string.rst | 11 ++ velox/functions/lib/string/StringCore.h | 36 ++++- velox/functions/lib/string/StringImpl.h | 29 ++-- velox/functions/prestosql/StringFunctions.cpp | 130 ------------------ velox/functions/prestosql/StringFunctions.h | 48 +++++++ .../StringFunctionsRegistration.cpp | 8 +- .../prestosql/tests/StringFunctionsTest.cpp | 89 ++++++++++++ 7 files changed, 207 insertions(+), 144 deletions(-) diff --git a/velox/docs/functions/string.rst b/velox/docs/functions/string.rst index 6a23d36f76a8..30f7efe6e195 100644 --- a/velox/docs/functions/string.rst +++ b/velox/docs/functions/string.rst @@ -107,6 +107,17 @@ String Functions ``instance`` must be a positive number. Positions start with ``1``. If not found, ``0`` is returned. +.. function:: strrpos(string, substring) -> bigint + + Returns the starting position of the last instance of ``substring`` in + ``string``. Positions start with ``1``. If not found, ``0`` is returned. + +.. function:: strrpos(string, substring, instance) -> bigint + + Returns the position of the N-th ``instance`` of ``substring`` in ``string`` starting from the end of the string. + ``instance`` must be a positive number. + Positions start with ``1``. If not found, ``0`` is returned. + .. function:: substr(string, start) -> varchar Returns the rest of ``string`` from the starting position ``start``. diff --git a/velox/functions/lib/string/StringCore.h b/velox/functions/lib/string/StringCore.h index 53a31903edbe..c0778f8567ef 100644 --- a/velox/functions/lib/string/StringCore.h +++ b/velox/functions/lib/string/StringCore.h @@ -231,7 +231,7 @@ lengthUnicode(const char* inputBuffer, size_t bufferLength) { /// substring and then computing the length of substring[0, byteIndex). This is /// safe because in UTF8 a char can not be a subset of another char (in bytes /// representation). -static int64_t findNthInstanceByteIndex( +static int64_t findNthInstanceByteIndexFromStart( const std::string_view& string, const std::string_view subString, const size_t instance = 1, @@ -254,10 +254,39 @@ static int64_t findNthInstanceByteIndex( } // Find next occurrence - return findNthInstanceByteIndex( + return findNthInstanceByteIndexFromStart( string, subString, instance - 1, byteIndex + subString.size()); } +/// Returns the start byte index of the Nth instance of subString in +/// string from the end. Search starts from endPosition. Positions start with 0. +/// If not found, -1 is returned. +inline int64_t findNthInstanceByteIndexFromEnd( + const std::string_view string, + const std::string_view subString, + const size_t instance = 1) { + assert(instance > 0); + + if (subString.empty()) { + return 0; + } + + size_t foundCnt = 0; + size_t index = string.size(); + do { + if (index == 0) { + return -1; + } + + index = string.rfind(subString, index - 1); + if (index == std::string_view::npos) { + return -1; + } + ++foundCnt; + } while (foundCnt < instance); + return index; +} + /// Replace replaced with replacement in inputString and write results in /// outputString. If inPlace=true inputString and outputString are assumed to /// tbe the same. When replaced is empty, replacement is added before and after @@ -281,7 +310,8 @@ inline static size_t replace( bool doCopyUnreplaced = !inPlace || (replaced.size() != replacement.size()); auto findNextReplaced = [&]() { - return findNthInstanceByteIndex(inputString, replaced, 1, readPosition); + return findNthInstanceByteIndexFromStart( + inputString, replaced, 1, readPosition); }; auto writeUnchanged = [&](ssize_t size) { diff --git a/velox/functions/lib/string/StringImpl.h b/velox/functions/lib/string/StringImpl.h index f2dd7f7e5420..79ca8552e4cf 100644 --- a/velox/functions/lib/string/StringImpl.h +++ b/velox/functions/lib/string/StringImpl.h @@ -148,22 +148,30 @@ FOLLY_ALWAYS_INLINE int32_t charToCodePoint(const T& inputString) { return codePoint; } -/// Returns the starting position in characters of the Nth instance of the -/// substring in string. Positions start with 1. If not found, 0 is returned. If -/// subString is empty result is 1. -template +/// Returns the starting position in characters of the Nth instance(counting +/// from the left if lpos==true and from the end otherwise) of the substring in +/// string. Positions start with 1. If not found, 0 is returned. If subString is +/// empty result is 1. +template FOLLY_ALWAYS_INLINE int64_t stringPosition(const T& string, const T& subString, int64_t instance = 0) { + VELOX_USER_CHECK_GT(instance, 0, "'instance' must be a positive number"); if (subString.size() == 0) { return 1; } - VELOX_USER_CHECK_GT(instance, 0, "'instance' must be a positive number"); - - auto byteIndex = findNthInstanceByteIndex( - std::string_view(string.data(), string.size()), - std::string_view(subString.data(), subString.size()), - instance); + int64_t byteIndex = -1; + if constexpr (lpos) { + byteIndex = findNthInstanceByteIndexFromStart( + std::string_view(string.data(), string.size()), + std::string_view(subString.data(), subString.size()), + instance); + } else { + byteIndex = findNthInstanceByteIndexFromEnd( + std::string_view(string.data(), string.size()), + std::string_view(subString.data(), subString.size()), + instance); + } if (byteIndex == -1) { return 0; @@ -700,4 +708,5 @@ FOLLY_ALWAYS_INLINE void pad( padString.data(), padPrefixByteLength); } + } // namespace facebook::velox::functions::stringImpl diff --git a/velox/functions/prestosql/StringFunctions.cpp b/velox/functions/prestosql/StringFunctions.cpp index b4e2aa584928..81c8eadc4e87 100644 --- a/velox/functions/prestosql/StringFunctions.cpp +++ b/velox/functions/prestosql/StringFunctions.cpp @@ -278,131 +278,6 @@ class ConcatFunction : public exec::VectorFunction { std::vector constantStringViews_; }; -/** - * strpos(string, substring) → bigint - * Returns the starting position of the first instance of substring in string. - * Positions start with 1. If not found, 0 is returned. - * - * strpos(string, substring, instance) → bigint - * Returns the position of the N-th instance of substring in string. instance - * must be a positive number. Positions start with 1. If not found, 0 is - * returned. - **/ -class StringPosition : public exec::VectorFunction { - private: - /// A function that can be wrapped with ascii mode - template - struct ApplyInternal { - template < - typename StringReader, - typename SubStringReader, - typename InstanceReader> - static void apply( - StringReader stringReader, - SubStringReader subStringReader, - InstanceReader instanceReader, - const SelectivityVector& rows, - FlatVector* resultFlatVector) { - rows.applyToSelected([&](int row) { - auto result = stringImpl::stringPosition( - stringReader(row), subStringReader(row), instanceReader(row)); - resultFlatVector->set(row, result); - }); - } - }; - - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /* outputType */, - exec::EvalCtx& context, - VectorPtr& result) const override { - exec::DecodedArgs decodedArgs(rows, args, context); - auto decodedStringInput = decodedArgs.at(0); - auto decodedSubStringInput = decodedArgs.at(1); - - auto stringArgStringEncoding = isAscii(args.at(0).get(), rows); - context.ensureWritable(rows, BIGINT(), result); - - auto* resultFlatVector = result->as>(); - - auto stringReader = [&](const vector_size_t row) { - return decodedStringInput->valueAt(row); - }; - - auto substringReader = [&](const vector_size_t row) { - return decodedSubStringInput->valueAt(row); - }; - - // If there's no "instance" parameter. - if (args.size() <= 2) { - StringEncodingTemplateWrapper::apply( - stringArgStringEncoding, - stringReader, - substringReader, - [](const vector_size_t) { return 1L; }, - rows, - resultFlatVector); - } - // If there's an "instance" parameter, check if it's BIGINT or INTEGER. - else { - auto decodedInstanceInput = decodedArgs.at(2); - - if (args[2]->typeKind() == TypeKind::BIGINT) { - auto instanceReader = [&](const vector_size_t row) { - return decodedInstanceInput->valueAt(row); - }; - StringEncodingTemplateWrapper::apply( - stringArgStringEncoding, - stringReader, - substringReader, - instanceReader, - rows, - resultFlatVector); - } else if (args[2]->typeKind() == TypeKind::INTEGER) { - auto instanceReader = [&](const vector_size_t row) { - return decodedInstanceInput->valueAt(row); - }; - StringEncodingTemplateWrapper::apply( - stringArgStringEncoding, - stringReader, - substringReader, - instanceReader, - rows, - resultFlatVector); - } else { - VELOX_UNREACHABLE(); - } - } - } - - static std::vector> signatures() { - return { - // varchar, varchar -> bigint - exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("varchar") - .argumentType("varchar") - .build(), - // varchar, varchar, integer -> bigint - exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("varchar") - .argumentType("varchar") - .argumentType("integer") - .build(), - // varchar, varchar, bigint -> bigint - exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("varchar") - .argumentType("varchar") - .argumentType("bigint") - .build(), - }; - } -}; - /** * replace(string, search) → varchar * Removes all instances of search from string. @@ -577,11 +452,6 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION_WITH_METADATA( return std::make_unique(name, inputs); }); -VELOX_DECLARE_VECTOR_FUNCTION( - udf_strpos, - StringPosition::signatures(), - std::make_unique()); - VELOX_DECLARE_VECTOR_FUNCTION( udf_replace, Replace::signatures(), diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index b323a3804a24..75d8b233c5fe 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -15,6 +15,7 @@ */ #pragma once +#include #define XXH_INLINE_ALL #include @@ -349,4 +350,51 @@ struct LPadFunction : public PadFunctionBase {}; template struct RPadFunction : public PadFunctionBase {}; +/// strpos and strrpos functions +/// strpos(string, substring) → bigint +/// Returns the starting position of the first instance of substring in +/// string. Positions start with 1. If not found, 0 is returned. +/// strpos(string, substring, instance) → bigint +/// Returns the position of the N-th instance of substring in string. +/// instance must be a positive number. Positions start with 1. If not +/// found, 0 is returned. +/// strrpos(string, substring) → bigint +/// Returns the starting position of the first instance of substring in +/// string counting from the end. Positions start with 1. If not found, 0 is +/// returned. +/// strrpos(string, substring, instance) → bigint +/// Returns the position of the N-th instance of substring in string +/// counting from the end. Instance must be a positive number. Positions +/// start with 1. If not found, 0 is returned. +template +struct StrPosFunctionBase { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& string, + const arg_type& subString, + const arg_type& instance = 1) { + result = stringImpl::stringPosition( + string, subString, instance); + return true; + } + + FOLLY_ALWAYS_INLINE bool callAscii( + out_type& result, + const arg_type& string, + const arg_type& subString, + const arg_type& instance = 1) { + result = stringImpl::stringPosition( + string, subString, instance); + return true; + } +}; + +template +struct StrLPosFunction : public StrPosFunctionBase {}; + +template +struct StrRPosFunction : public StrPosFunctionBase {}; + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp index 513d7001422a..bb3f664d602e 100644 --- a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp @@ -82,7 +82,6 @@ void registerStringFunctions() { VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, "upper"); VELOX_REGISTER_VECTOR_FUNCTION(udf_split, "split"); VELOX_REGISTER_VECTOR_FUNCTION(udf_concat, "concat"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_strpos, "strpos"); VELOX_REGISTER_VECTOR_FUNCTION(udf_replace, "replace"); VELOX_REGISTER_VECTOR_FUNCTION(udf_reverse, "reverse"); VELOX_REGISTER_VECTOR_FUNCTION(udf_to_utf8, "to_utf8"); @@ -94,5 +93,12 @@ void registerStringFunctions() { "regexp_extract_all", re2ExtractAllSignatures(), makeRe2ExtractAll); exec::registerStatefulVectorFunction( "regexp_like", re2SearchSignatures(), makeRe2Search); + + registerFunction({"strpos"}); + registerFunction( + {"strpos"}); + registerFunction({"strrpos"}); + registerFunction( + {"strrpos"}); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index 6bf68da3c7a7..668549e50447 100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -299,6 +299,12 @@ class StringFunctionsTest : public FunctionBaseTest { const std::vector>& stringEncodings, bool withInstanceArgument); + template + void testStringPositionFromEndAllFlatVector( + const strpos_input_test_t& tests, + const std::vector>& stringEncodings, + bool withInstanceArgument); + void testChrFlatVector( const std::vector>& tests); @@ -969,6 +975,89 @@ TEST_F(StringFunctionsTest, stringPosition) { } } +// Test strpos function +template +void StringFunctionsTest::testStringPositionFromEndAllFlatVector( + const strpos_input_test_t& tests, + const std::vector>& asciiEncodings, + bool withInstanceArgument) { + auto stringVector = makeFlatVector(tests.size()); + auto subStringVector = makeFlatVector(tests.size()); + auto instanceVector = + withInstanceArgument ? makeFlatVector(tests.size()) : nullptr; + + for (int i = 0; i < tests.size(); i++) { + stringVector->set(i, StringView(std::get<0>(tests[i].first))); + subStringVector->set(i, StringView(std::get<1>(tests[i].first))); + if (instanceVector) { + instanceVector->set(i, std::get<2>(tests[i].first)); + } + } + + if (asciiEncodings[0].has_value()) { + stringVector->setAllIsAscii(asciiEncodings[0].value()); + } + if (asciiEncodings[1].has_value()) { + subStringVector->setAllIsAscii(asciiEncodings[1].value()); + } + + FlatVectorPtr result; + if (withInstanceArgument) { + result = evaluate>( + "strrpos(c0, c1,c2)", + makeRowVector({stringVector, subStringVector, instanceVector})); + } else { + result = evaluate>( + "strrpos(c0, c1)", makeRowVector({stringVector, subStringVector})); + } + + for (int32_t i = 0; i < tests.size(); ++i) { + ASSERT_EQ(result->valueAt(i), tests[i].second); + } +} + +TEST_F(StringFunctionsTest, stringPositionFromEnd) { + strpos_input_test_t testsAscii = { + {{"high", "ig", -1}, {2}}, + {{"high", "igx", -1}, {0}}, + {{"high", "h", -1}, {4}}, + {{"", "h", -1}, {0}}, + {{"high", "", -1}, {1}}, + {{"", "", -1}, {1}}, + }; + + strpos_input_test_t testsAsciiWithPosition = { + {{"high", "h", 2}, 1}, + {{"high", "h", 10}, 0}, + {{"high", "", 2}, {1}}, + {{"", "", 2}, {1}}, + }; + + strpos_input_test_t testsUnicodeWithPosition = { + {{"\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u7231", 1}, 4}, + {{"\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u5E0C\u671B", 1}, 6}, + }; + + // We dont have to try all encoding combinations here since there is a test + // that test the encoding resolution but we want to to have a test for each + // possible resolution + testStringPositionFromEndAllFlatVector( + testsAscii, {true, true}, false); + + // Try instance parameter using BIGINT and INTEGER. + testStringPositionFromEndAllFlatVector( + testsAsciiWithPosition, {false, false}, true); + testStringPositionFromEndAllFlatVector( + testsAsciiWithPosition, {false, false}, true); + + // Test constant vectors + auto rows = makeRowVector(makeRowType({BIGINT()}), 10); + auto result = evaluate>("strrpos('high', 'ig')", rows); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(result->valueAt(i), 2); + } +} + void StringFunctionsTest::testChrFlatVector( const std::vector>& tests) { auto codePoints = makeFlatVector(tests.size());