Skip to content

Commit

Permalink
fix decimal avg intermediate output type (apache#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored Mar 16, 2023
1 parent 77bb133 commit 7c6f73b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ package io.glutenproject.utils

import scala.math.min

import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType}
import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}

object GlutenDecimalUtil {
Expand All @@ -29,4 +30,11 @@ object GlutenDecimalUtil {
/**
 * Builds a [[DecimalType]] whose precision and scale are capped at
 * `MAX_PRECISION` and `MAX_SCALE` respectively.
 *
 * @param precision requested precision; values above `MAX_PRECISION` are clamped
 * @param scale     requested scale; values above `MAX_SCALE` are clamped
 * @return the decimal type with the clamped precision and scale
 */
def bounded(precision: Int, scale: Int): DecimalType = {
  val cappedPrecision = min(precision, MAX_PRECISION)
  val cappedScale = min(scale, MAX_SCALE)
  DecimalType(cappedPrecision, cappedScale)
}

/**
 * Returns the data type of the `sum` slot in the intermediate buffer of `avg`.
 *
 * For a decimal input Decimal(p, s), avg.dataType is Decimal(p + 4, s + 4) and the sum type
 * is Decimal(p + 10, s), both capped by [[bounded]]. For any other input the sum is a double.
 *
 * @param avg the average aggregate function whose intermediate sum type is needed
 * @return the bounded decimal sum type for decimal inputs, otherwise `DoubleType`
 */
def getAvgSumDataType(avg: Average): DataType = avg.child.dataType match {
  // Derive the sum type from the input type instead of undoing the "+ 4" on avg.dataType:
  // avg.dataType has already been capped at MAX_PRECISION / MAX_SCALE, so subtracting 4
  // from it does not recover the input scale once the cap was hit (e.g. an input of
  // Decimal(36, 36) yields an avg type of Decimal(38, 38), and 38 - 4 = 34, not 36).
  case _ @ GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p + 10, s)
  case _ => DoubleType
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ case class GlutenHashAggregateExecTransformer(
aggregateFunction match {
case avg: Average =>
structTypeNodes.add(ConverterUtils.getTypeNode(
avg.dataType match {
case _ @ GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p + 10, s)
case _ => DoubleType
}, nullable = true))
GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
Expand Down Expand Up @@ -194,13 +191,7 @@ case class GlutenHashAggregateExecTransformer(
case Final =>
val dataType = aggregateFunction match {
case avg: Average =>
avg.dataType match {
case _ @ GlutenDecimalUtil.Fixed(p, s) =>
// avg.dataType is Decimal(p + 4, s + 4) and sumType is Decimal(p + 10, s)
// we need to get sumType, so p = p - 4 + 10 and s = s - 4
GlutenDecimalUtil.bounded(p - 4 + 10, s - 4)
case _ => DoubleType
}
GlutenDecimalUtil.getAvgSumDataType(avg)
case _ =>
aggregateFunction.dataType
}
Expand Down

0 comments on commit 7c6f73b

Please sign in to comment.