From c12534290f3d4caf7059e92518a7e7afad84908a Mon Sep 17 00:00:00 2001 From: Deepak Majeti Date: Wed, 26 Apr 2023 09:45:37 -0700 Subject: [PATCH] Implement decimal comparisons as vector functions (#4638) Summary: Decimal Simple Functions require separate signatures/registration for ShortDecimal and LongDecimal types. We want to view Decimal Types as a single entity/type. Simple Functions break this view. To achieve the single entity view, all Decimal functions must be implemented as Vector Functions. The Decimal comparison functions are currently implemented as Simple Functions. The scope of this PR is to implement them as Vector Functions. Pull Request resolved: https://github.com/facebookincubator/velox/pull/4638 Reviewed By: mbasmanova Differential Revision: D45193675 Pulled By: Yuhta fbshipit-source-id: 5f78d04464ec910577ca8c353493fa66cb6890ab --- velox/core/SimpleFunctionMetadata.h | 20 +---- velox/functions/prestosql/CMakeLists.txt | 2 +- velox/functions/prestosql/Comparisons.cpp | 84 +++++++++++------- ...thmetic.cpp => DecimalVectorFunctions.cpp} | 88 ++++++++++++++++++- .../ComparisonFunctionsRegistration.cpp | 41 +-------- .../prestosql/tests/ComparisonsTest.cpp | 9 +- velox/type/Type.cpp | 9 -- velox/type/Type.h | 2 - 8 files changed, 152 insertions(+), 103 deletions(-) rename velox/functions/prestosql/{DecimalArithmetic.cpp => DecimalVectorFunctions.cpp} (85%) diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 2778e8b17bc3..ac18fa0834db 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -27,9 +27,6 @@ #include "velox/type/Type.h" #include "velox/type/Variant.h" -inline constexpr char kPrecisionVariable[] = "a_precision"; -inline constexpr char kScaleVariable[] = "a_scale"; - namespace facebook::velox::core { // Most UDFs are deterministic, hence this default value. @@ -194,14 +191,8 @@ struct TypeAnalysis { SimpleTypeTrait::isPrimitiveType || SimpleTypeTrait::typeKind == TypeKind::OPAQUE); results.stats.concreteCount++; - if (isDecimalKind(SimpleTypeTrait::typeKind)) { - results.out << detail::strToLowerCopy( - std::string(SimpleTypeTrait::name)) - << "(" << kPrecisionVariable << "," << kScaleVariable << ")"; - } else { - results.out << detail::strToLowerCopy( - std::string(SimpleTypeTrait::name)); - } + results.out << detail::strToLowerCopy( + std::string(SimpleTypeTrait::name)); } }; @@ -443,15 +434,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { builder.returnType(analysis.outputType); - bool isDecimalArg = false; for (const auto& arg : analysis.argsTypes) { builder.argumentType(arg); - isDecimalArg |= isDecimalTypeSignature(arg); - } - - if (isDecimalArg) { - builder.integerVariable(kPrecisionVariable); - builder.integerVariable(kScaleVariable); } for (const auto& variable : analysis.variables) { diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index da946bfa3f86..4e6864b5654c 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -28,7 +28,7 @@ add_library( ArraySort.cpp ArraySum.cpp Comparisons.cpp - DecimalArithmetic.cpp + DecimalVectorFunctions.cpp ElementAt.cpp FilterFunctions.cpp FromUnixTime.cpp diff --git a/velox/functions/prestosql/Comparisons.cpp b/velox/functions/prestosql/Comparisons.cpp index 1cba6fbd5ed8..b52ccb666f76 100644 --- a/velox/functions/prestosql/Comparisons.cpp +++ b/velox/functions/prestosql/Comparisons.cpp @@ -97,9 +97,10 @@ struct SimdComparator { template < TypeKind kind, typename std::enable_if_t< - xsimd::has_simd_register< - typename TypeTraits::NativeType>::value && - kind != TypeKind::BOOLEAN, + (xsimd::has_simd_register< + typename TypeTraits::NativeType>::value && + kind != TypeKind::BOOLEAN) || + kind == TypeKind::SHORT_DECIMAL || kind == TypeKind::LONG_DECIMAL, int> = 0> void applyComparison( const SelectivityVector& rows, @@ -112,11 +113,16 @@ struct SimdComparator { auto resultVector = result->asUnchecked>(); auto rawResult = resultVector->mutableRawValues(); + // UnscaledShortDecimal and UnscaledLongDecimal will soon be replaced with + // int64_t and int128_t respectively. This change will be removed then. + constexpr bool isDecimal = + (std::is_same_v || + std::is_same_v); auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && (rhs.isConstantEncoding() || rhs.isFlatEncoding()) && rows.isAllSelected(); - if (!isSimdizable) { + if (!isSimdizable || isDecimal) { exec::LocalDecodedVector lhsDecoded(context, lhs, rows); exec::LocalDecodedVector rhsDecoded(context, rhs, rows); @@ -128,38 +134,40 @@ struct SimdComparator { }); return; } + if constexpr (!isDecimal) { + if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { + auto l = lhs.asUnchecked>()->valueAt(0); + auto r = rhs.asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), &l, &r, rawResult); + } else if (lhs.isConstantEncoding()) { + auto l = lhs.asUnchecked>()->valueAt(0); + auto rawRhs = rhs.asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), &l, rawRhs, rawResult); + } else if (rhs.isConstantEncoding()) { + auto rawLhs = lhs.asUnchecked>()->rawValues(); + auto r = rhs.asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, &r, rawResult); + } else { + auto rawLhs = lhs.asUnchecked>()->rawValues(); + auto rawRhs = rhs.asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); + } - if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { - auto l = lhs.asUnchecked>()->valueAt(0); - auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( - rows.begin(), rows.end(), &l, &r, rawResult); - } else if (lhs.isConstantEncoding()) { - auto l = lhs.asUnchecked>()->valueAt(0); - auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( - rows.begin(), rows.end(), &l, rawRhs, rawResult); - } else if (rhs.isConstantEncoding()) { - auto rawLhs = lhs.asUnchecked>()->rawValues(); - auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( - rows.begin(), rows.end(), rawLhs, &r, rawResult); - } else { - auto rawLhs = lhs.asUnchecked>()->rawValues(); - auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( - rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); + resultVector->clearNulls(rows); } - - resultVector->clearNulls(rows); } template < TypeKind kind, typename std::enable_if_t< - !xsimd::has_simd_register< - typename TypeTraits::NativeType>::value || - kind == TypeKind::BOOLEAN, + (!xsimd::has_simd_register< + typename TypeTraits::NativeType>::value || + kind == TypeKind::BOOLEAN) && + kind != TypeKind::SHORT_DECIMAL && kind != TypeKind::LONG_DECIMAL, int> = 0> void applyComparison( const SelectivityVector& /* rows */, @@ -187,6 +195,16 @@ class ComparisonSimdFunction : public exec::VectorFunction { context.ensureWritable(rows, outputType, result); auto comparator = SimdComparator{}; + if (args[0]->type()->isShortDecimal()) { + comparator.template applyComparison( + rows, *args[0], *args[1], context, result); + return; + } else if (args[0]->type()->isLongDecimal()) { + comparator.template applyComparison( + rows, *args[0], *args[1], context, result); + return; + } + VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( comparator.template applyComparison, args[0]->typeKind(), @@ -208,7 +226,13 @@ class ComparisonSimdFunction : public exec::VectorFunction { .argumentType(inputType) .build()); } - + signatures.push_back(exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType("boolean") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .build()); return signatures; } diff --git a/velox/functions/prestosql/DecimalArithmetic.cpp b/velox/functions/prestosql/DecimalVectorFunctions.cpp similarity index 85% rename from velox/functions/prestosql/DecimalArithmetic.cpp rename to velox/functions/prestosql/DecimalVectorFunctions.cpp index 5246188989cb..b1af12246f79 100644 --- a/velox/functions/prestosql/DecimalArithmetic.cpp +++ b/velox/functions/prestosql/DecimalVectorFunctions.cpp @@ -146,10 +146,59 @@ class DecimalUnaryBaseFunction : public exec::VectorFunction { return result->asUnchecked>()->mutableRawValues(); } - private: const uint8_t aRescale_; }; +template +class DecimalBetweenFunction : public exec::VectorFunction { + public: + DecimalBetweenFunction() {} + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const override { + prepareResults(rows, resultType, context, result); + // Second and third arguments must always be constant. + VELOX_CHECK(args[1]->isConstantEncoding() && args[2]->isConstantEncoding()); + auto constantB = args[1]->asUnchecked>()->valueAt(0); + auto constantC = args[2]->asUnchecked>()->valueAt(0); + if (args[0]->isFlatEncoding()) { + // Fast path if first argument is flat. + auto flatA = args[0]->asUnchecked>(); + auto rawA = flatA->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + result->asUnchecked>()->set( + row, + rawA[row].unscaledValue() >= constantB.unscaledValue() && + rawA[row].unscaledValue() <= constantC.unscaledValue()); + }); + } else { + // Path if first argument is encoded. + exec::DecodedArgs decodedArgs(rows, args, context); + auto a = decodedArgs.at(0); + context.applyToSelectedNoThrow(rows, [&](auto row) { + auto value = a->valueAt(row); + result->asUnchecked>()->set( + row, + value.unscaledValue() >= constantB.unscaledValue() && + value.unscaledValue() <= constantC.unscaledValue()); + }); + } + } + + private: + void prepareResults( + const SelectivityVector& rows, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const { + context.ensureWritable(rows, resultType, result); + result->clearNulls(rows); + } +}; + class Addition { public: template @@ -431,6 +480,18 @@ decimalAbsNegateSignature() { .build()}; } +std::vector> +decimalBetweenSignature() { + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType("BOOLEAN") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .build()}; +} + template std::shared_ptr createDecimalUnary( const std::string& /*name*/, @@ -516,6 +577,26 @@ std::shared_ptr createDecimalFunction( } VELOX_UNSUPPORTED(); } + +std::shared_ptr createDecimalBetweenFunction( + const std::string& name, + const std::vector& inputArgs) { + auto aType = inputArgs[0].type; + auto bType = inputArgs[1].type; + auto cType = inputArgs[2].type; + if (aType->kind() == TypeKind::SHORT_DECIMAL) { + VELOX_CHECK(bType->kind() == TypeKind::SHORT_DECIMAL); + VELOX_CHECK(cType->kind() == TypeKind::SHORT_DECIMAL); + // Arguments are short decimals. + return std::make_shared>(); + } else { + VELOX_CHECK(bType->kind() == TypeKind::LONG_DECIMAL); + VELOX_CHECK(cType->kind() == TypeKind::LONG_DECIMAL); + // Arguments are long decimals. + return std::make_shared>(); + } + VELOX_UNSUPPORTED(); +} }; // namespace VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( @@ -552,4 +633,9 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_negate, decimalAbsNegateSignature(), createDecimalUnary); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_between, + decimalBetweenSignature(), + createDecimalBetweenFunction); }; // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp index c95468442ed8..3e43371646d1 100644 --- a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp @@ -58,47 +58,8 @@ void registerComparisonFunctions(const std::string& prefix) { {prefix + "between"}); registerFunction( {prefix + "between"}); - registerFunction< - BetweenFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "between"}); - registerFunction< - BetweenFunction, - bool, - UnscaledLongDecimal, - UnscaledLongDecimal, - UnscaledLongDecimal>({prefix + "between"}); - registerFunction< - GtFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "gt"}); - registerFunction( - {prefix + "gt"}); - registerFunction< - LtFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "lt"}); - registerFunction( - {prefix + "lt"}); - registerFunction< - GteFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "gte"}); - registerFunction( - {prefix + "gte"}); - registerFunction< - LteFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "lte"}); - registerFunction( - {prefix + "lte"}); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_between, prefix + "between"); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index 682c109d5629..d1b8c14f9063 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -107,7 +107,7 @@ TEST_F(ComparisonsTest, betweenDecimal) { VectorPtr input, VectorPtr expectedResult) { auto actual = evaluate>(exprStr, makeRowVector({input})); - test::assertEqualVectors(actual, expectedResult); + test::assertEqualVectors(expectedResult, actual); }; auto shortFlat = makeNullableShortDecimalFlatVector( @@ -120,6 +120,11 @@ TEST_F(ComparisonsTest, betweenDecimal) { auto longFlat = makeNullableLongDecimalFlatVector( {100, 250, 300, 500, std::nullopt}, DECIMAL(20, 2)); + runAndCompare( + "c0 between cast(2.00 as DECIMAL(20, 2)) and cast(3.00 as DECIMAL(20, 2))", + longFlat, + expectedResult); + // Comparing LONG_DECIMAL and SHORT_DECIMAL must throw error. VELOX_ASSERT_THROW( runAndCompare("c0 between 2.00 and 3.00", longFlat, expectedResult), @@ -159,7 +164,7 @@ TEST_F(ComparisonsTest, gtLtDecimal) { std::vector& inputs, VectorPtr expectedResult) { auto actual = evaluate>(expr, makeRowVector(inputs)); - test::assertEqualVectors(actual, expectedResult); + test::assertEqualVectors(expectedResult, actual); }; // Short Decimals test. diff --git a/velox/type/Type.cpp b/velox/type/Type.cpp index b2231af0b04a..989aee4333e7 100644 --- a/velox/type/Type.cpp +++ b/velox/type/Type.cpp @@ -47,15 +47,6 @@ bool isDecimalName(const std::string& typeName) { typeNameUpper == TypeTraits::name); } -bool isDecimalTypeSignature(const std::string& arg) { - auto upper = boost::algorithm::to_upper_copy(arg); - return ( - upper.find(TypeTraits::name) != - std::string::npos || - upper.find(TypeTraits::name) != - std::string::npos); -} - // Static variable intialization is not thread safe for non // constant-initialization, but scoped static initialization is thread safe. const std::unordered_map& getTypeStringMap() { diff --git a/velox/type/Type.h b/velox/type/Type.h index cd5d0b0d5824..0f837536abc3 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -723,8 +723,6 @@ inline bool isDecimalKind(TypeKind typeKind) { bool isDecimalName(const std::string& typeName); -bool isDecimalTypeSignature(const std::string& arg); - std::pair getDecimalPrecisionScale(const Type& type); class UnknownType : public TypeBase {