Skip to content

Commit

Permalink
fix final avg decimal output type (apache#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and JkSelf committed Mar 21, 2023
1 parent 0bf8192 commit ca92969
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,10 @@ abstract class HashAggregateExecBaseTransformer(
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeProjectRel(
input, preExprNodes, extensionNode, context, operatorId, emitStartIndex)

}

// Handle the pure Aggregate after Projection. Both grouping and Aggregate expressions are
Expand Down Expand Up @@ -425,6 +427,7 @@ abstract class HashAggregateExecBaseTransformer(
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeProjectRel(
aggRel, resExprNodes, extensionNode, context, operatorId, emitStartIndex)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

import java.util

import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.utils.GlutenDecimalUtil

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, LongType}
import org.apache.spark.sql.types.{DecimalType, DoubleType, LongType}

case class GlutenHashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand Down Expand Up @@ -122,11 +122,13 @@ case class GlutenHashAggregateExecTransformer(
}
}
if (!validation) {
RelBuilder.makeProjectRel(aggRel, expressionNodes, context, operatorId, groupingExpressions.size + aggregateExpressions.size)
RelBuilder.makeProjectRel(aggRel, expressionNodes, context, operatorId,
groupingExpressions.size + aggregateExpressions.size)
} else {
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, getPartialAggOutTypes).toProtobuf))
RelBuilder.makeProjectRel(aggRel, expressionNodes, extensionNode, context, operatorId, groupingExpressions.size + aggregateExpressions.size)
RelBuilder.makeProjectRel(aggRel, expressionNodes, extensionNode, context, operatorId,
groupingExpressions.size + aggregateExpressions.size)
}
}

Expand Down Expand Up @@ -175,11 +177,23 @@ case class GlutenHashAggregateExecTransformer(
getIntermediateTypeNode(aggregateFunction))
aggregateNodeList.add(partialNode)
case Final =>
val dataType = aggregateFunction match {
case avg: Average =>
avg.dataType match {
case _ @ GlutenDecimalUtil.Fixed(p, s) =>
// avg.dataType is Decimal(p + 4, s + 4) and sumType is Decimal(p + 10, s)
// we need to get sumType, so p = p - 4 + 10 and s = s - 4
GlutenDecimalUtil.bounded(p - 4 + 10, s - 4)
case _ => DoubleType
}
case _ =>
aggregateFunction.dataType
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeToKeyWord(aggregateMode),
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
ConverterUtils.getTypeNode(dataType, aggregateFunction.nullable))
aggregateNodeList.add(aggFunctionNode)
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
Expand Down Expand Up @@ -351,7 +365,8 @@ case class GlutenHashAggregateExecTransformer(
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(inputRel, exprNodes, extensionNode, context, operatorId, emitStartIndex)
RelBuilder.makeProjectRel(inputRel, exprNodes, extensionNode,
context, operatorId, emitStartIndex)
}

// Create aggregation rel.
Expand Down

0 comments on commit ca92969

Please sign in to comment.