Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Nov 17, 2023
1 parent febe46c commit 8a4ba83
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
24 changes: 17 additions & 7 deletions velox/functions/sparksql/aggregates/AverageAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "velox/functions/sparksql/aggregates/AverageAggregate.h"
#include "velox/functions/lib/aggregates/AverageAggregateBase.h"
#include "velox/functions/sparksql/DecimalUtil.h"

using namespace facebook::velox::functions::aggregate;

Expand Down Expand Up @@ -218,7 +219,7 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
char* group = groups[i];
auto* accumulator = this->decimalAccumulator(group);
std::optional<int128_t> validSum =
DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow);
DecimalUtil::adjustSumForOverflow(accumulator->sum, accumulator->overflow);
if (validSum.has_value()) {
rawCounts[i] = accumulator->count;
rawSums[i] = validSum.value();
Expand Down Expand Up @@ -260,8 +261,8 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {

std::optional<TResultType> computeAvg(
LongDecimalWithOverflowState* accumulator) {
std::optional<int128_t> validSum =
DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow);
std::optional<int128_t> validSum = DecimalUtil::adjustSumForOverflow(
accumulator->sum, accumulator->overflow);
if (!validSum.has_value()) {
return std::nullopt;
}
Expand All @@ -278,10 +279,19 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
auto countDecimal = accumulator->count;
int128_t avg = 0;

DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
avg, validSum.value(), countDecimal, false, sumRescale, 0);
return DecimalUtil::rescaleWithRoundUp<int128_t, TResultType>(
avg, avgPrecision, avgScale, resultPrecision, resultScale);
bool overflow = false;
functions::sparksql::DecimalUtil::divideWithRoundUp<int128_t, int128_t, int128_t>(
avg, validSum.value(), countDecimal, sumRescale, overflow);
if (overflow) {
return std::nullopt;
}
auto rescaledValue = DecimalUtil::rescaleWithRoundUp<int128_t, TResultType>(
avg,
avgPrecision,
avgScale,
resultPrecision,
resultScale);
return overflow ? std::nullopt : rescaledValue;
}

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{"spark_avg(c0)"},
{},
{makeRowVector({makeNullableFlatVector<int64_t>(
{3'000 * kRescale}, DECIMAL(14, 5))})});
{3'000 * kRescale}, DECIMAL(14, 5))})},
/*config*/ {},
/*testWithTableScan*/ false);

// Long decimal aggregation
testAggregations(
Expand All @@ -139,7 +141,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{},
{makeRowVector({makeFlatVector(
std::vector<int128_t>{HugeInt::build(10, 300) * kRescale},
DECIMAL(27, 8))})});
DECIMAL(27, 8))})},
/*config*/ {},
/*testWithTableScan*/ false);

// The total sum overflows the max int128_t limit.
std::vector<int128_t> rawVector;
Expand All @@ -153,7 +157,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{},
{makeRowVector({makeNullableFlatVector(
std::vector<std::optional<int128_t>>{std::nullopt},
DECIMAL(38, 4))})});
DECIMAL(38, 4))})},
/*config*/ {},
/*testWithTableScan*/ false);

// The total sum underflows the min int128_t limit.
rawVector.clear();
Expand All @@ -167,7 +173,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{},
{"spark_avg(c0)"},
{},
{makeRowVector({underFlowTestResult})});
{makeRowVector({underFlowTestResult})},
/*config*/ {},
/*testWithTableScan*/ false);

// Test constant vector.
testAggregations(
Expand All @@ -176,7 +184,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{"spark_avg(c0)"},
{},
{makeRowVector({makeFlatVector(
std::vector<int64_t>{100 * kRescale}, DECIMAL(14, 6))})});
std::vector<int64_t>{100 * kRescale}, DECIMAL(14, 6))})},
/*config*/ {},
/*testWithTableScan*/ false);

auto newSize = shortDecimal->size() * 2;
auto indices = makeIndices(newSize, [&](int row) { return row / 2; });
Expand All @@ -189,7 +199,9 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{"spark_avg(c0)"},
{},
{makeRowVector({makeFlatVector(
std::vector<int64_t>{3'000 * kRescale}, DECIMAL(14, 5))})});
std::vector<int64_t>{3'000 * kRescale}, DECIMAL(14, 5))})},
/*config*/ {},
/*testWithTableScan*/ false);

// Decimal average aggregation with multiple groups.
auto inputRows = {
Expand Down Expand Up @@ -221,7 +233,13 @@ TEST_F(AverageAggregationTest, avgDecimal) {
{makeNullableFlatVector<int32_t>({3}),
makeFlatVector(std::vector<int128_t>{-24976667}, DECIMAL(19, 6))})};

testAggregations(inputRows, {"c0"}, {"spark_avg(c1)"}, expectedResult);
testAggregations(
inputRows,
{"c0"},
{"spark_avg(c1)"},
expectedResult,
/*config*/ {},
/*testWithTableScan*/ false);
}

TEST_F(AverageAggregationTest, avgDecimalWithMultipleRowVectors) {
Expand All @@ -235,7 +253,13 @@ TEST_F(AverageAggregationTest, avgDecimalWithMultipleRowVectors) {
auto expectedResult = {makeRowVector(
{makeFlatVector(std::vector<int128_t>{350 * kRescale}, DECIMAL(19, 6))})};

testAggregations(inputRows, {}, {"spark_avg(c0)"}, expectedResult);
testAggregations(
inputRows,
{},
{"spark_avg(c0)"},
expectedResult,
/*config*/ {},
/*testWithTableScan*/ false);
}

} // namespace
Expand Down

0 comments on commit 8a4ba83

Please sign in to comment.