From 65cb72a7d5298c8e1b63c64a8b6b3b04876ab581 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Mon, 7 Aug 2023 18:37:06 +0800 Subject: [PATCH] sparksql support decimal avg --- .../lib/aggregates/AverageAggregateBase.cpp | 6 +- .../lib/aggregates/DecimalAggregate.h | 11 +- .../prestosql/aggregates/SumAggregate.h | 18 +- .../sparksql/aggregates/AverageAggregate.cpp | 335 ++++++++++++++++-- .../tests/AverageAggregationTest.cpp | 127 +++++++ velox/type/DecimalUtil.h | 22 +- 6 files changed, 466 insertions(+), 53 deletions(-) diff --git a/velox/functions/lib/aggregates/AverageAggregateBase.cpp b/velox/functions/lib/aggregates/AverageAggregateBase.cpp index efef798b6202..3353caed48b6 100644 --- a/velox/functions/lib/aggregates/AverageAggregateBase.cpp +++ b/velox/functions/lib/aggregates/AverageAggregateBase.cpp @@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate { void checkAvgIntermediateType(const TypePtr& type) { VELOX_USER_CHECK( type->isRow() || type->isVarbinary(), - "Input type for final average must be row type or varbinary type."); + "Input type for final average must be row type or varbinary type, find {}", + type->toString()); if (type->kind() == TypeKind::VARBINARY) { return; } VELOX_USER_CHECK( type->childAt(0)->kind() == TypeKind::DOUBLE || type->childAt(0)->isLongDecimal(), - "Input type for sum in final average must be double or long decimal type.") + "Input type for sum in final average must be double or long decimal type, find {}", + type->childAt(0)->toString()); VELOX_USER_CHECK_EQ( type->childAt(1)->kind(), TypeKind::BIGINT, diff --git a/velox/functions/lib/aggregates/DecimalAggregate.h b/velox/functions/lib/aggregates/DecimalAggregate.h index 695ec8f5d8ab..9d867ab1bb8d 100644 --- a/velox/functions/lib/aggregates/DecimalAggregate.h +++ b/velox/functions/lib/aggregates/DecimalAggregate.h @@ -74,11 +74,11 @@ class DecimalAggregate : public exec::Aggregate { explicit DecimalAggregate(TypePtr resultType) : exec::Aggregate(resultType) {} int32_t accumulatorFixedWidthSize() const override { - return sizeof(DecimalAggregate); + return sizeof(LongDecimalWithOverflowState); } int32_t accumulatorAlignmentSize() const override { - return static_cast(sizeof(int128_t)); + return alignof(LongDecimalWithOverflowState); } void initializeNewGroups( @@ -287,7 +287,9 @@ class DecimalAggregate : public exec::Aggregate { } virtual TResultType computeFinalValue( - LongDecimalWithOverflowState* accumulator) = 0; + LongDecimalWithOverflowState* accumulator) { + return 0; + }; void extractValues(char** groups, int32_t numGroups, VectorPtr* result) override { @@ -329,11 +331,12 @@ class DecimalAggregate : public exec::Aggregate { accumulator->count += 1; } - private: + protected: inline LongDecimalWithOverflowState* decimalAccumulator(char* group) { return exec::Aggregate::value(group); } + private: DecodedVector decodedRaw_; DecodedVector decodedPartial_; }; diff --git a/velox/functions/prestosql/aggregates/SumAggregate.h b/velox/functions/prestosql/aggregates/SumAggregate.h index 33273390a0d7..4aeb7ca51ebb 100644 --- a/velox/functions/prestosql/aggregates/SumAggregate.h +++ b/velox/functions/prestosql/aggregates/SumAggregate.h @@ -180,19 +180,11 @@ class DecimalSumAggregate virtual int128_t computeFinalValue( functions::aggregate::LongDecimalWithOverflowState* accumulator) final { - // Value is valid if the conditions below are true. - int128_t sum = accumulator->sum; - if ((accumulator->overflow == 1 && accumulator->sum < 0) || - (accumulator->overflow == -1 && accumulator->sum > 0)) { - sum = static_cast( - DecimalUtil::kOverflowMultiplier * accumulator->overflow + - accumulator->sum); - } else { - VELOX_CHECK(accumulator->overflow == 0, "Decimal overflow"); - } - - DecimalUtil::valueInRange(sum); - return sum; + auto sum = + DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow); + VELOX_CHECK(sum.has_value(), "Decimal overflow"); + DecimalUtil::valueInRange(sum.value()); + return sum.value(); } }; diff --git a/velox/functions/sparksql/aggregates/AverageAggregate.cpp b/velox/functions/sparksql/aggregates/AverageAggregate.cpp index 8c52551304be..4e334ca101a6 100644 --- a/velox/functions/sparksql/aggregates/AverageAggregate.cpp +++ b/velox/functions/sparksql/aggregates/AverageAggregate.cpp @@ -74,6 +74,277 @@ class AverageAggregate } }; +template +class DecimalAverageAggregate : public DecimalAggregate { + public: + explicit DecimalAverageAggregate(TypePtr resultType) + : DecimalAggregate(resultType) {} + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto sumVector = baseRowVector->childAt(0)->as>(); + auto countVector = baseRowVector->childAt(1)->as>(); + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + // Find overflow, set all groups to null. + rows.applyToSelected( + [&](vector_size_t i) { this->setNull(groups[i]); }); + } else { + auto sum = sumVector->valueAt(decodedIndex); + rows.applyToSelected([&](vector_size_t i) { + this->clearNull(groups[i]); + auto accumulator = this->decimalAccumulator(groups[i]); + mergeSumCount(accumulator, sum, count); + }); + } + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (decodedPartial_.isNullAt(i)) { + return; + } + this->clearNull(groups[i]); + auto decodedIndex = decodedPartial_.index(i); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + this->setNull(groups[i]); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto accumulator = this->decimalAccumulator(groups[i]); + mergeSumCount(accumulator, sum, count); + } + }); + } else { + rows.applyToSelected([&](vector_size_t i) { + this->clearNull(groups[i]); + auto decodedIndex = decodedPartial_.index(i); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + this->setNull(groups[i]); + } else { + auto sum = sumVector->valueAt(decodedIndex); + auto accumulator = this->decimalAccumulator(groups[i]); + mergeSumCount(accumulator, sum, count); + } + }); + } + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /* mayPushdown */) override { + decodedPartial_.decode(*args[0], rows); + auto baseRowVector = dynamic_cast(decodedPartial_.base()); + auto sumVector = baseRowVector->childAt(0)->as>(); + auto countVector = baseRowVector->childAt(1)->as>(); + + if (decodedPartial_.isConstantMapping()) { + if (!decodedPartial_.isNullAt(0)) { + auto decodedIndex = decodedPartial_.index(0); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + // Find overflow, just set group to null and return. + this->setNull(group); + return; + } else { + if (rows.hasSelections()) { + this->clearNull(group); + } + auto sum = sumVector->valueAt(decodedIndex); + rows.applyToSelected( + [&](vector_size_t i) { mergeAccumulators(group, sum, count); }); + } + } + } else if (decodedPartial_.mayHaveNulls()) { + rows.applyToSelected([&](vector_size_t i) { + if (!decodedPartial_.isNullAt(i)) { + this->clearNull(group); + auto decodedIndex = decodedPartial_.index(i); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + // Find overflow, just set group to null. + this->setNull(group); + } else { + auto sum = sumVector->valueAt(decodedIndex); + mergeAccumulators(group, sum, count); + } + } + }); + } else { + if (rows.hasSelections()) { + this->clearNull(group); + } + rows.applyToSelected([&](vector_size_t i) { + auto decodedIndex = decodedPartial_.index(i); + auto count = countVector->valueAt(decodedIndex); + if (sumVector->isNullAt(decodedIndex) && + !countVector->isNullAt(decodedIndex) && count > 0) { + // Find overflow, just set group to null. + this->setNull(group); + } else { + auto sum = sumVector->valueAt(decodedIndex); + mergeAccumulators(group, sum, count); + } + }); + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto rowVector = (*result)->as(); + auto sumVector = rowVector->childAt(0)->asFlatVector(); + auto countVector = rowVector->childAt(1)->asFlatVector(); + rowVector->resize(numGroups); + sumVector->resize(numGroups); + countVector->resize(numGroups); + rowVector->clearAllNulls(); + + int64_t* rawCounts = countVector->mutableRawValues(); + int128_t* rawSums = sumVector->mutableRawValues(); + for (auto i = 0; i < numGroups; ++i) { + char* group = groups[i]; + auto* accumulator = this->decimalAccumulator(group); + std::optional validSum = + DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow); + if (validSum.has_value()) { + rawCounts[i] = accumulator->count; + rawSums[i] = validSum.value(); + } else { + // Find overflow. + sumVector->setNull(i, true); + rawCounts[i] = accumulator->count; + } + } + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + auto vector = (*result)->as>(); + VELOX_CHECK(vector); + vector->resize(numGroups); + uint64_t* rawNulls = this->getRawNulls(vector); + + TResultType* rawValues = vector->mutableRawValues(); + for (int32_t i = 0; i < numGroups; ++i) { + char* group = groups[i]; + auto accumulator = this->decimalAccumulator(group); + if (accumulator->count == 0) { + // In Spark, if all inputs are null, count will be 0, + // and the result of final avg will be null. + vector->setNull(i, true); + } else { + this->clearNull(rawNulls, i); + std::optional avg = computeAvg(accumulator); + if (avg.has_value()) { + rawValues[i] = avg.value(); + } else { + // Find overflow. + vector->setNull(i, true); + } + } + } + } + + std::optional computeAvg( + LongDecimalWithOverflowState* accumulator) { + std::optional validSum = + DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow); + if (!validSum.has_value()) { + return std::nullopt; + } + + auto [resultPrecision, resultScale] = + getDecimalPrecisionScale(*this->resultType().get()); + // Spark use DECIMAL(20,0) to represent long value + const uint8_t countPrecision = 20, countScale = 0; + uint8_t sumPrecision = resultPrecision - 4 + 10; + uint8_t sumScale = resultScale - 4; + auto [avgPrecision, avgScale] = computeResultPrecisionScale( + sumPrecision, sumScale, countPrecision, countScale); + auto sumRescale = computeRescaleFactor(sumScale, countScale, avgScale); + auto countDecimal = accumulator->count; + int128_t avg = 0; + + DecimalUtil::divideWithRoundUp( + avg, validSum.value(), countDecimal, false, sumRescale, 0); + return DecimalUtil::rescaleWithRoundUp( + avg, avgPrecision, avgScale, resultPrecision, resultScale); + } + + private: + template + inline void mergeSumCount( + LongDecimalWithOverflowState* accumulator, + UnscaledType sum, + int64_t count) { + accumulator->count += count; + accumulator->overflow += + DecimalUtil::addWithOverflow(accumulator->sum, sum, accumulator->sum); + } + + template + void mergeAccumulators( + char* group, + const UnscaledType& otherSum, + const int64_t& otherCount) { + if constexpr (tableHasNulls) { + exec::Aggregate::clearNull(group); + } + auto accumulator = this->decimalAccumulator(group); + mergeSumCount(accumulator, otherSum, otherCount); + } + + inline static uint8_t + computeRescaleFactor(uint8_t fromScale, uint8_t toScale, uint8_t rScale) { + return rScale - fromScale + toScale; + } + + inline static std::pair computeResultPrecisionScale( + const uint8_t aPrecision, + const uint8_t aScale, + const uint8_t bPrecision, + const uint8_t bScale) { + uint8_t intDig = aPrecision - aScale + bScale; + uint8_t scale = std::max(6, aScale + bPrecision + 1); + uint8_t precision = intDig + scale; + return adjustPrecisionScale(precision, scale); + } + + inline static std::pair adjustPrecisionScale( + const uint8_t precision, + const uint8_t scale) { + VELOX_CHECK(precision >= scale); + if (precision <= 38) { + return {precision, scale}; + } else { + uint8_t intDigits = precision - scale; + uint8_t minScaleValue = std::min(scale, (uint8_t)6); + uint8_t adjustedScale = + std::max((uint8_t)(38 - intDigits), minScaleValue); + return {38, adjustedScale}; + } + } + + DecodedVector decodedRaw_; + DecodedVector decodedPartial_; +}; + } // namespace /// Count is BIGINT() while sum and the final aggregates type depends on @@ -96,13 +367,16 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) { .build()); } - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .integerVariable("a_precision") - .integerVariable("a_scale") - .argumentType("DECIMAL(a_precision, a_scale)") - .intermediateType("varbinary") - .returnType("DECIMAL(a_precision, a_scale)") - .build()); + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("r_precision", "min(38, a_precision + 4)") + .integerVariable("r_scale", "min(38, a_scale + 4)") + .argumentType("DECIMAL(a_precision, a_scale)") + .intermediateType("ROW(DECIMAL(38 , a_scale), BIGINT)") + .returnType("DECIMAL(r_precision, r_scale)") + .build()); return exec::registerAggregateFunction( name, @@ -115,7 +389,7 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) { -> std::unique_ptr { VELOX_CHECK_LE( argTypes.size(), 1, "{} takes at most one argument", name); - auto inputType = argTypes[0]; + const auto& inputType = argTypes[0]; if (exec::isRawInput(step)) { switch (inputType->kind()) { case TypeKind::SMALLINT: @@ -126,16 +400,28 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) { AverageAggregate>(resultType); case TypeKind::BIGINT: { if (inputType->isShortDecimal()) { - return std::make_unique>( - resultType); + if (exec::isPartialOutput(step)) { + return std::make_unique< + DecimalAverageAggregate>(resultType); + } else { + if (resultType->isShortDecimal()) { + return std::make_unique< + DecimalAverageAggregate>(resultType); + } else if (resultType->isLongDecimal()) { + return std::make_unique< + DecimalAverageAggregate>(resultType); + } else { + VELOX_FAIL("Result type must be decimal"); + } + } } return std::make_unique< AverageAggregate>(resultType); } case TypeKind::HUGEINT: { if (inputType->isLongDecimal()) { - return std::make_unique>( - resultType); + return std::make_unique< + DecimalAverageAggregate>(resultType); } VELOX_NYI(); } @@ -159,27 +445,18 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) { resultType); case TypeKind::DOUBLE: case TypeKind::ROW: + if (inputType->childAt(0)->isLongDecimal()) { + return std::make_unique< + DecimalAverageAggregate>(resultType); + } return std::make_unique< AverageAggregate>(resultType); case TypeKind::BIGINT: - return std::make_unique>( - resultType); + return std::make_unique< + DecimalAverageAggregate>(resultType); case TypeKind::HUGEINT: - return std::make_unique>( - resultType); - case TypeKind::VARBINARY: - if (inputType->isLongDecimal()) { - return std::make_unique>( - resultType); - } else if ( - inputType->isShortDecimal() || - inputType->kind() == TypeKind::VARBINARY) { - // If the input and out type are VARBINARY, then the - // LongDecimalWithOverflowState is used and the template type - // does not matter. - return std::make_unique>( - resultType); - } + return std::make_unique< + DecimalAverageAggregate>(resultType); default: VELOX_FAIL( "Unsupported result type for final aggregation: {}", diff --git a/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp index fdef8a189bf5..28464885ea02 100644 --- a/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/AverageAggregationTest.cpp @@ -111,5 +111,132 @@ TEST_F(AverageAggregationTest, avgAllNulls) { assertQuery(plan, expected); } +TEST_F(AverageAggregationTest, avgDecimal) { + int64_t kRescale = DecimalUtil::kPowersOfTen[4]; + // Short decimal aggregation + auto shortDecimal = makeNullableFlatVector( + {1'000, 2'000, 3'000, 4'000, 5'000, std::nullopt}, DECIMAL(10, 1)); + testAggregations( + {makeRowVector({shortDecimal})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({makeNullableFlatVector( + {3'000 * kRescale}, DECIMAL(14, 5))})}); + + // Long decimal aggregation + testAggregations( + {makeRowVector({makeNullableFlatVector( + {HugeInt::build(10, 100), + HugeInt::build(10, 200), + HugeInt::build(10, 300), + HugeInt::build(10, 400), + HugeInt::build(10, 500), + std::nullopt}, + DECIMAL(23, 4))})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({makeFlatVector( + std::vector{HugeInt::build(10, 300) * kRescale}, + DECIMAL(27, 8))})}); + + // The total sum overflows the max int128_t limit. + std::vector rawVector; + for (int i = 0; i < 10; ++i) { + rawVector.push_back(DecimalUtil::kLongDecimalMax); + } + testAggregations( + {makeRowVector({makeFlatVector(rawVector, DECIMAL(38, 0))})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({makeNullableFlatVector( + std::vector>{std::nullopt}, + DECIMAL(38, 4))})}); + + // The total sum underflows the min int128_t limit. + rawVector.clear(); + auto underFlowTestResult = makeNullableFlatVector( + std::vector>{std::nullopt}, DECIMAL(38, 4)); + for (int i = 0; i < 10; ++i) { + rawVector.push_back(DecimalUtil::kLongDecimalMin); + } + testAggregations( + {makeRowVector({makeFlatVector(rawVector, DECIMAL(38, 0))})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({underFlowTestResult})}); + + // Test constant vector. + testAggregations( + {makeRowVector({makeConstant(100, 10, DECIMAL(10, 2))})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({makeFlatVector( + std::vector{100 * kRescale}, DECIMAL(14, 6))})}); + + auto newSize = shortDecimal->size() * 2; + auto indices = makeIndices(newSize, [&](int row) { return row / 2; }); + auto dictVector = + VectorTestBase::wrapInDictionary(indices, newSize, shortDecimal); + + testAggregations( + {makeRowVector({dictVector})}, + {}, + {"spark_avg(c0)"}, + {}, + {makeRowVector({makeFlatVector( + std::vector{3'000 * kRescale}, DECIMAL(14, 5))})}); + + // Decimal average aggregation with multiple groups. + auto inputRows = { + makeRowVector( + {makeNullableFlatVector({1, 1}), + makeFlatVector({37220, 53450}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 2}), + makeFlatVector({10410, 9250}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({3, 3}), + makeFlatVector({-12783, 0}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({1, 2}), + makeFlatVector({23178, 41093}, DECIMAL(15, 2))}), + makeRowVector( + {makeNullableFlatVector({2, 3}), + makeFlatVector({-10023, 5290}, DECIMAL(15, 2))}), + }; + + auto expectedResult = { + makeRowVector( + {makeNullableFlatVector({1}), + makeFlatVector(std::vector{379493333}, DECIMAL(19, 6))}), + makeRowVector( + {makeNullableFlatVector({2}), + makeFlatVector(std::vector{126825000}, DECIMAL(19, 6))}), + makeRowVector( + {makeNullableFlatVector({3}), + makeFlatVector(std::vector{-24976667}, DECIMAL(19, 6))})}; + + testAggregations(inputRows, {"c0"}, {"spark_avg(c1)"}, expectedResult); +} + +TEST_F(AverageAggregationTest, avgDecimalWithMultipleRowVectors) { + int64_t kRescale = DecimalUtil::kPowersOfTen[4]; + auto inputRows = { + makeRowVector({makeFlatVector({100, 200}, DECIMAL(15, 2))}), + makeRowVector({makeFlatVector({300, 400}, DECIMAL(15, 2))}), + makeRowVector({makeFlatVector({500, 600}, DECIMAL(15, 2))}), + }; + + auto expectedResult = {makeRowVector( + {makeFlatVector(std::vector{350 * kRescale}, DECIMAL(19, 6))})}; + + testAggregations(inputRows, {}, {"spark_avg(c0)"}, expectedResult); +} + } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index adb90bdb9302..fd57b5df4890 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -117,11 +117,7 @@ class DecimalUtil { // Check overflow. if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] || rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) { - VELOX_USER_FAIL( - "Cannot cast DECIMAL '{}' to DECIMAL({},{})", - DecimalUtil::toString(inputValue, DECIMAL(fromPrecision, fromScale)), - toPrecision, - toScale); + return std::nullopt; } return static_cast(rescaledValue); } @@ -249,6 +245,22 @@ class DecimalUtil { } } + inline static std::optional computeValidSum( + int128_t sum, + int64_t overflow) { + // Value is valid if the conditions below are true. + int128_t validSum = sum; + if ((overflow == 1 && sum < 0) || (overflow == -1 && sum > 0)) { + validSum = static_cast( + DecimalUtil::kOverflowMultiplier * overflow + sum); + } else { + if (overflow != 0) { + return std::nullopt; + } + } + return validSum; + } + /// Origins from java side BigInteger#bitLength. /// /// Returns the number of bits in the minimal two's-complement