Skip to content

Commit

Permalink
support decimal sum agg (apache#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 authored and JkSelf committed Mar 23, 2023
1 parent 3f315c5 commit d17b23a
Showing 1 changed file with 64 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, DoubleType, LongType}
import org.apache.spark.sql.types.{BooleanType, DecimalType, DoubleType, LongType}

case class GlutenHashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
Expand Down Expand Up @@ -68,6 +68,12 @@ case class GlutenHashAggregateExecTransformer(
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial =>
return true
case _ =>
}
case _ =>
}
}
Expand Down Expand Up @@ -116,6 +122,12 @@ case class GlutenHashAggregateExecTransformer(
// Select m2 from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
colIdx += 1
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
// Select sum from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
// Select isEmpty from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _ =>
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
Expand Down Expand Up @@ -152,6 +164,9 @@ case class GlutenHashAggregateExecTransformer(
structTypeNodes.add(ConverterUtils.getTypeNode(LongType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(DoubleType, nullable = true))
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
structTypeNodes.add(ConverterUtils.getTypeNode(sum.dataType, nullable = true))
structTypeNodes.add(ConverterUtils.getTypeNode(BooleanType, nullable = false))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
Expand Down Expand Up @@ -198,6 +213,25 @@ case class GlutenHashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
aggregateMode match {
case Partial =>
val partialNode = ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeToKeyWord(aggregateMode),
getIntermediateTypeNode(aggregateFunction))
aggregateNodeList.add(partialNode)
case Final =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeToKeyWord(aggregateMode),
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
aggregateNodeList.add(aggFunctionNode)
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
AggregateFunctionsBuilder.create(args, aggregateFunction),
Expand Down Expand Up @@ -230,6 +264,16 @@ case class GlutenHashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expression.mode match {
case Partial =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case _ =>
typeNodeList.add(
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
Expand Down Expand Up @@ -341,6 +385,22 @@ case class GlutenHashAggregateExecTransformer(
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
  // Decimal Sum: only the Final stage is handled here; any other mode is rejected below.
  aggregateExpression.mode match {
    case Final =>
      // The intermediate result arrives as two separate columns (the sibling branches
      // select them from the Velox struct as `sum` and `isEmpty`).
      assert(functionInputAttributes.size == 2,
        "Final stage of decimal Sum expects two input attributes.")
      // Use a Velox function to combine the intermediate columns into struct.
      val childNodes = new util.ArrayList[ExpressionNode](
        functionInputAttributes.toList.map(attr => {
          ExpressionConverter
            .replaceWithExpressionTransformer(attr, originalInputAttributes)
            .doTransform(args)
        }).asJava)
      exprNodes.add(getRowConstructNode(args, childNodes, functionInputAttributes))
    case other =>
      throw new UnsupportedOperationException(s"$other is not supported.")
  }
case _ =>
assert(functionInputAttributes.size == 1, "Only one input attribute is expected.")
val childNodes = new util.ArrayList[ExpressionNode](
Expand Down Expand Up @@ -395,6 +455,9 @@ case class GlutenHashAggregateExecTransformer(
// by previous projection.
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
colIdx += 1
case _ =>
aggregateFunc.inputAggBufferAttributes.toList.map(_ => {
childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
Expand Down

0 comments on commit d17b23a

Please sign in to comment.