diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 2778e8b17bc3..ac18fa0834db 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -27,9 +27,6 @@ #include "velox/type/Type.h" #include "velox/type/Variant.h" -inline constexpr char kPrecisionVariable[] = "a_precision"; -inline constexpr char kScaleVariable[] = "a_scale"; - namespace facebook::velox::core { // Most UDFs are deterministic, hence this default value. @@ -194,14 +191,8 @@ struct TypeAnalysis { SimpleTypeTrait::isPrimitiveType || SimpleTypeTrait::typeKind == TypeKind::OPAQUE); results.stats.concreteCount++; - if (isDecimalKind(SimpleTypeTrait::typeKind)) { - results.out << detail::strToLowerCopy( - std::string(SimpleTypeTrait::name)) - << "(" << kPrecisionVariable << "," << kScaleVariable << ")"; - } else { - results.out << detail::strToLowerCopy( - std::string(SimpleTypeTrait::name)); - } + results.out << detail::strToLowerCopy( + std::string(SimpleTypeTrait::name)); } }; @@ -443,15 +434,8 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { builder.returnType(analysis.outputType); - bool isDecimalArg = false; for (const auto& arg : analysis.argsTypes) { builder.argumentType(arg); - isDecimalArg |= isDecimalTypeSignature(arg); - } - - if (isDecimalArg) { - builder.integerVariable(kPrecisionVariable); - builder.integerVariable(kScaleVariable); } for (const auto& variable : analysis.variables) { diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index da946bfa3f86..4e6864b5654c 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -28,7 +28,7 @@ add_library( ArraySort.cpp ArraySum.cpp Comparisons.cpp - DecimalArithmetic.cpp + DecimalVectorFunctions.cpp ElementAt.cpp FilterFunctions.cpp FromUnixTime.cpp diff --git a/velox/functions/prestosql/Comparisons.cpp b/velox/functions/prestosql/Comparisons.cpp index 1cba6fbd5ed8..b52ccb666f76 100644 --- a/velox/functions/prestosql/Comparisons.cpp +++ b/velox/functions/prestosql/Comparisons.cpp @@ -97,9 +97,10 @@ struct SimdComparator { template < TypeKind kind, typename std::enable_if_t< - xsimd::has_simd_register< - typename TypeTraits::NativeType>::value && - kind != TypeKind::BOOLEAN, + (xsimd::has_simd_register< + typename TypeTraits::NativeType>::value && + kind != TypeKind::BOOLEAN) || + kind == TypeKind::SHORT_DECIMAL || kind == TypeKind::LONG_DECIMAL, int> = 0> void applyComparison( const SelectivityVector& rows, @@ -112,11 +113,16 @@ struct SimdComparator { auto resultVector = result->asUnchecked>(); auto rawResult = resultVector->mutableRawValues(); + // UnscaledShortDecimal and UnscaledLongDecimal will soon be replaced with + // int64_t and int128_t respectively. This change will be removed then. + constexpr bool isDecimal = + (std::is_same_v || + std::is_same_v); auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && (rhs.isConstantEncoding() || rhs.isFlatEncoding()) && rows.isAllSelected(); - if (!isSimdizable) { + if (!isSimdizable || isDecimal) { exec::LocalDecodedVector lhsDecoded(context, lhs, rows); exec::LocalDecodedVector rhsDecoded(context, rhs, rows); @@ -128,38 +134,40 @@ struct SimdComparator { }); return; } + if constexpr (!isDecimal) { + if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { + auto l = lhs.asUnchecked>()->valueAt(0); + auto r = rhs.asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), &l, &r, rawResult); + } else if (lhs.isConstantEncoding()) { + auto l = lhs.asUnchecked>()->valueAt(0); + auto rawRhs = rhs.asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), &l, rawRhs, rawResult); + } else if (rhs.isConstantEncoding()) { + auto rawLhs = lhs.asUnchecked>()->rawValues(); + auto r = rhs.asUnchecked>()->valueAt(0); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, &r, rawResult); + } else { + auto rawLhs = lhs.asUnchecked>()->rawValues(); + auto rawRhs = rhs.asUnchecked>()->rawValues(); + applySimdComparison( + rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); + } - if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { - auto l = lhs.asUnchecked>()->valueAt(0); - auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( - rows.begin(), rows.end(), &l, &r, rawResult); - } else if (lhs.isConstantEncoding()) { - auto l = lhs.asUnchecked>()->valueAt(0); - auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( - rows.begin(), rows.end(), &l, rawRhs, rawResult); - } else if (rhs.isConstantEncoding()) { - auto rawLhs = lhs.asUnchecked>()->rawValues(); - auto r = rhs.asUnchecked>()->valueAt(0); - applySimdComparison( - rows.begin(), rows.end(), rawLhs, &r, rawResult); - } else { - auto rawLhs = lhs.asUnchecked>()->rawValues(); - auto rawRhs = rhs.asUnchecked>()->rawValues(); - applySimdComparison( - rows.begin(), rows.end(), rawLhs, rawRhs, rawResult); + resultVector->clearNulls(rows); } - - resultVector->clearNulls(rows); } template < TypeKind kind, typename std::enable_if_t< - !xsimd::has_simd_register< - typename TypeTraits::NativeType>::value || - kind == TypeKind::BOOLEAN, + (!xsimd::has_simd_register< + typename TypeTraits::NativeType>::value || + kind == TypeKind::BOOLEAN) && + kind != TypeKind::SHORT_DECIMAL && kind != TypeKind::LONG_DECIMAL, int> = 0> void applyComparison( const SelectivityVector& /* rows */, @@ -187,6 +195,16 @@ class ComparisonSimdFunction : public exec::VectorFunction { context.ensureWritable(rows, outputType, result); auto comparator = SimdComparator{}; + if (args[0]->type()->isShortDecimal()) { + comparator.template applyComparison( + rows, *args[0], *args[1], context, result); + return; + } else if (args[0]->type()->isLongDecimal()) { + comparator.template applyComparison( + rows, *args[0], *args[1], context, result); + return; + } + VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( comparator.template applyComparison, args[0]->typeKind(), @@ -208,7 +226,13 @@ class ComparisonSimdFunction : public exec::VectorFunction { .argumentType(inputType) .build()); } - + signatures.push_back(exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType("boolean") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .build()); return signatures; } diff --git a/velox/functions/prestosql/DecimalArithmetic.cpp b/velox/functions/prestosql/DecimalVectorFunctions.cpp similarity index 85% rename from velox/functions/prestosql/DecimalArithmetic.cpp rename to velox/functions/prestosql/DecimalVectorFunctions.cpp index 5246188989cb..b1af12246f79 100644 --- a/velox/functions/prestosql/DecimalArithmetic.cpp +++ b/velox/functions/prestosql/DecimalVectorFunctions.cpp @@ -146,10 +146,59 @@ class DecimalUnaryBaseFunction : public exec::VectorFunction { return result->asUnchecked>()->mutableRawValues(); } - private: const uint8_t aRescale_; }; +template +class DecimalBetweenFunction : public exec::VectorFunction { + public: + DecimalBetweenFunction() {} + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const override { + prepareResults(rows, resultType, context, result); + // Second and third arguments must always be constant. + VELOX_CHECK(args[1]->isConstantEncoding() && args[2]->isConstantEncoding()); + auto constantB = args[1]->asUnchecked>()->valueAt(0); + auto constantC = args[2]->asUnchecked>()->valueAt(0); + if (args[0]->isFlatEncoding()) { + // Fast path if first argument is flat. + auto flatA = args[0]->asUnchecked>(); + auto rawA = flatA->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + result->asUnchecked>()->set( + row, + rawA[row].unscaledValue() >= constantB.unscaledValue() && + rawA[row].unscaledValue() <= constantC.unscaledValue()); + }); + } else { + // Path if first argument is encoded. + exec::DecodedArgs decodedArgs(rows, args, context); + auto a = decodedArgs.at(0); + context.applyToSelectedNoThrow(rows, [&](auto row) { + auto value = a->valueAt(row); + result->asUnchecked>()->set( + row, + value.unscaledValue() >= constantB.unscaledValue() && + value.unscaledValue() <= constantC.unscaledValue()); + }); + } + } + + private: + void prepareResults( + const SelectivityVector& rows, + const TypePtr& resultType, + exec::EvalCtx& context, + VectorPtr& result) const { + context.ensureWritable(rows, resultType, result); + result->clearNulls(rows); + } +}; + class Addition { public: template @@ -431,6 +480,18 @@ decimalAbsNegateSignature() { .build()}; } +std::vector> +decimalBetweenSignature() { + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType("BOOLEAN") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .build()}; +} + template std::shared_ptr createDecimalUnary( const std::string& /*name*/, @@ -516,6 +577,26 @@ std::shared_ptr createDecimalFunction( } VELOX_UNSUPPORTED(); } + +std::shared_ptr createDecimalBetweenFunction( + const std::string& name, + const std::vector& inputArgs) { + auto aType = inputArgs[0].type; + auto bType = inputArgs[1].type; + auto cType = inputArgs[2].type; + if (aType->kind() == TypeKind::SHORT_DECIMAL) { + VELOX_CHECK(bType->kind() == TypeKind::SHORT_DECIMAL); + VELOX_CHECK(cType->kind() == TypeKind::SHORT_DECIMAL); + // Arguments are short decimals. + return std::make_shared>(); + } else { + VELOX_CHECK(bType->kind() == TypeKind::LONG_DECIMAL); + VELOX_CHECK(cType->kind() == TypeKind::LONG_DECIMAL); + // Arguments are long decimals. + return std::make_shared>(); + } + VELOX_UNSUPPORTED(); +} }; // namespace VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( @@ -552,4 +633,9 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_decimal_negate, decimalAbsNegateSignature(), createDecimalUnary); + +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( + udf_decimal_between, + decimalBetweenSignature(), + createDecimalBetweenFunction); }; // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp index c95468442ed8..3e43371646d1 100644 --- a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp @@ -58,47 +58,8 @@ void registerComparisonFunctions(const std::string& prefix) { {prefix + "between"}); registerFunction( {prefix + "between"}); - registerFunction< - BetweenFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "between"}); - registerFunction< - BetweenFunction, - bool, - UnscaledLongDecimal, - UnscaledLongDecimal, - UnscaledLongDecimal>({prefix + "between"}); - registerFunction< - GtFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "gt"}); - registerFunction( - {prefix + "gt"}); - registerFunction< - LtFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "lt"}); - registerFunction( - {prefix + "lt"}); - registerFunction< - GteFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "gte"}); - registerFunction( - {prefix + "gte"}); - registerFunction< - LteFunction, - bool, - UnscaledShortDecimal, - UnscaledShortDecimal>({prefix + "lte"}); - registerFunction( - {prefix + "lte"}); + VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_between, prefix + "between"); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index 682c109d5629..d1b8c14f9063 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -107,7 +107,7 @@ TEST_F(ComparisonsTest, betweenDecimal) { VectorPtr input, VectorPtr expectedResult) { auto actual = evaluate>(exprStr, makeRowVector({input})); - test::assertEqualVectors(actual, expectedResult); + test::assertEqualVectors(expectedResult, actual); }; auto shortFlat = makeNullableShortDecimalFlatVector( @@ -120,6 +120,11 @@ TEST_F(ComparisonsTest, betweenDecimal) { auto longFlat = makeNullableLongDecimalFlatVector( {100, 250, 300, 500, std::nullopt}, DECIMAL(20, 2)); + runAndCompare( + "c0 between cast(2.00 as DECIMAL(20, 2)) and cast(3.00 as DECIMAL(20, 2))", + longFlat, + expectedResult); + // Comparing LONG_DECIMAL and SHORT_DECIMAL must throw error. VELOX_ASSERT_THROW( runAndCompare("c0 between 2.00 and 3.00", longFlat, expectedResult), @@ -159,7 +164,7 @@ TEST_F(ComparisonsTest, gtLtDecimal) { std::vector& inputs, VectorPtr expectedResult) { auto actual = evaluate>(expr, makeRowVector(inputs)); - test::assertEqualVectors(actual, expectedResult); + test::assertEqualVectors(expectedResult, actual); }; // Short Decimals test. diff --git a/velox/type/Type.cpp b/velox/type/Type.cpp index b2231af0b04a..989aee4333e7 100644 --- a/velox/type/Type.cpp +++ b/velox/type/Type.cpp @@ -47,15 +47,6 @@ bool isDecimalName(const std::string& typeName) { typeNameUpper == TypeTraits::name); } -bool isDecimalTypeSignature(const std::string& arg) { - auto upper = boost::algorithm::to_upper_copy(arg); - return ( - upper.find(TypeTraits::name) != - std::string::npos || - upper.find(TypeTraits::name) != - std::string::npos); -} - // Static variable intialization is not thread safe for non // constant-initialization, but scoped static initialization is thread safe. const std::unordered_map& getTypeStringMap() { diff --git a/velox/type/Type.h b/velox/type/Type.h index cd5d0b0d5824..0f837536abc3 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -723,8 +723,6 @@ inline bool isDecimalKind(TypeKind typeKind) { bool isDecimalName(const std::string& typeName); -bool isDecimalTypeSignature(const std::string& arg); - std::pair getDecimalPrecisionScale(const Type& type); class UnknownType : public TypeBase {