Skip to content

Commit

Permalink
sparksql support decimal avg
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Aug 16, 2023
1 parent 1aab55f commit 65cb72a
Show file tree
Hide file tree
Showing 6 changed files with 466 additions and 53 deletions.
6 changes: 4 additions & 2 deletions velox/functions/lib/aggregates/AverageAggregateBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ namespace facebook::velox::functions::aggregate {
void checkAvgIntermediateType(const TypePtr& type) {
VELOX_USER_CHECK(
type->isRow() || type->isVarbinary(),
"Input type for final average must be row type or varbinary type.");
"Input type for final average must be row type or varbinary type, find {}",
type->toString());
if (type->kind() == TypeKind::VARBINARY) {
return;
}
VELOX_USER_CHECK(
type->childAt(0)->kind() == TypeKind::DOUBLE ||
type->childAt(0)->isLongDecimal(),
"Input type for sum in final average must be double or long decimal type.")
"Input type for sum in final average must be double or long decimal type, find {}",
type->childAt(0)->toString());
VELOX_USER_CHECK_EQ(
type->childAt(1)->kind(),
TypeKind::BIGINT,
Expand Down
11 changes: 7 additions & 4 deletions velox/functions/lib/aggregates/DecimalAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class DecimalAggregate : public exec::Aggregate {
explicit DecimalAggregate(TypePtr resultType) : exec::Aggregate(resultType) {}

int32_t accumulatorFixedWidthSize() const override {
return sizeof(DecimalAggregate);
return sizeof(LongDecimalWithOverflowState);
}

int32_t accumulatorAlignmentSize() const override {
return static_cast<int32_t>(sizeof(int128_t));
return alignof(LongDecimalWithOverflowState);
}

void initializeNewGroups(
Expand Down Expand Up @@ -287,7 +287,9 @@ class DecimalAggregate : public exec::Aggregate {
}

virtual TResultType computeFinalValue(
LongDecimalWithOverflowState* accumulator) = 0;
LongDecimalWithOverflowState* accumulator) {
return 0;
};

void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
override {
Expand Down Expand Up @@ -329,11 +331,12 @@ class DecimalAggregate : public exec::Aggregate {
accumulator->count += 1;
}

private:
protected:
inline LongDecimalWithOverflowState* decimalAccumulator(char* group) {
return exec::Aggregate::value<LongDecimalWithOverflowState>(group);
}

private:
DecodedVector decodedRaw_;
DecodedVector decodedPartial_;
};
Expand Down
18 changes: 5 additions & 13 deletions velox/functions/prestosql/aggregates/SumAggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,19 +180,11 @@ class DecimalSumAggregate

virtual int128_t computeFinalValue(
functions::aggregate::LongDecimalWithOverflowState* accumulator) final {
// Value is valid if the conditions below are true.
int128_t sum = accumulator->sum;
if ((accumulator->overflow == 1 && accumulator->sum < 0) ||
(accumulator->overflow == -1 && accumulator->sum > 0)) {
sum = static_cast<int128_t>(
DecimalUtil::kOverflowMultiplier * accumulator->overflow +
accumulator->sum);
} else {
VELOX_CHECK(accumulator->overflow == 0, "Decimal overflow");
}

DecimalUtil::valueInRange(sum);
return sum;
auto sum =
DecimalUtil::computeValidSum(accumulator->sum, accumulator->overflow);
VELOX_CHECK(sum.has_value(), "Decimal overflow");
DecimalUtil::valueInRange(sum.value());
return sum.value();
}
};

Expand Down
Loading

0 comments on commit 65cb72a

Please sign in to comment.