[GLUTEN-8772][CORE] refactor: Refactoring the use of SubstraitContext#functionMap
wypb committed Feb 25, 2025
1 parent 321a602 commit fdf1727
Showing 46 changed files with 207 additions and 249 deletions.
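The change is mechanical but wide: everywhere the Substrait function map was previously pulled out of the context (as a JMap[String, JLong], or an untyped Object) and threaded through doTransform and the builder helpers, the SubstraitContext itself is now passed instead. The following minimal, self-contained Scala sketch shows the before/after caller pattern. The simplified SubstraitContext below is an illustrative stand-in for org.apache.gluten.substrait.SubstraitContext, not the real class, and the size-based id assignment is only an assumption about how registerFunction behaves.

import java.lang.{Long => JLong}
import java.util.{HashMap => JHashMap, Map => JMap}

// Illustrative stand-in for org.apache.gluten.substrait.SubstraitContext, not the real class.
class SubstraitContext {
  private val functionMap: JMap[String, JLong] = new JHashMap[String, JLong]()

  // Assigns a fresh id to an unseen function name, or returns the existing one.
  def registerFunction(funcName: String): JLong = {
    if (!functionMap.containsKey(funcName)) {
      functionMap.put(funcName, JLong.valueOf(functionMap.size().toLong))
    }
    functionMap.get(funcName)
  }

  // The raw-map accessor that the old call sites passed around.
  def registeredFunction: JMap[String, JLong] = functionMap
}

object FunctionMapSketch {
  // Before: call sites received the mutable map and had to manage id allocation themselves.
  def functionIdBefore(args: JMap[String, JLong], funcName: String): JLong = {
    if (!args.containsKey(funcName)) args.put(funcName, JLong.valueOf(args.size().toLong))
    args.get(funcName)
  }

  // After: call sites hand the context to doTransform/builders and ask it for ids.
  def functionIdAfter(context: SubstraitContext, funcName: String): JLong =
    context.registerFunction(funcName)
}

In the hunks below, only the parameter changes at each call site (.doTransform(args) becomes .doTransform(context)); the id bookkeeping moves behind SubstraitContext#registerFunction.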
@@ -25,6 +25,7 @@ import org.apache.gluten.expression.ExpressionNames.MONOTONICALLY_INCREASING_ID
import org.apache.gluten.extension.ExpressionExtensionTrait
import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
import org.apache.gluten.vectorized.CHColumnarBatchSerializer
@@ -58,8 +59,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.commons.lang3.ClassUtils

import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -703,7 +703,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
args: JMap[String, JLong]): Unit = {
context: SubstraitContext): Unit = {

windowExpression.map {
windowExpr =>
@@ -715,7 +715,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
@@ -739,10 +739,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(expr, originalInputAttributes)
.doTransform(args)))
.doTransform(context)))

val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
CHExpressions.createAggregateFunction(args, aggExpression.aggregateFunction).toInt,
CHExpressions.createAggregateFunction(context, aggExpression.aggregateFunction).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
@@ -778,21 +778,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
.replaceWithExpressionTransformer(
offsetWf.input,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.offset,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.default,
attributeSeq = originalInputAttributes)
.doTransform(args))
.doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, offsetWf).toInt,
WindowFunctionsBuilder.create(context, offsetWf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
@@ -806,9 +806,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
val literal = buckets.asInstanceOf[Literal]
childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, wf).toInt,
WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
@@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.backendsapi.TransformerApi
import org.apache.gluten.execution.{CHHashAggregateExecTransformer, WriteFilesExecTransformer}
import org.apache.gluten.expression.ConverterUtils
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{BooleanLiteralNode, ExpressionBuilder, ExpressionNode}
import org.apache.gluten.utils.{CHInputPartitionsUtil, ExpressionDocUtil}

@@ -211,16 +212,14 @@ class CHTransformerApi extends TransformerApi with Logging {
}

override def createCheckOverflowExprNode(
args: java.lang.Object,
context: SubstraitContext,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionId = ExpressionBuilder.newScalarFunction(
functionMap,
val functionId = context.registerFunction(
ConverterUtils.makeFuncName(
substraitExprName,
Seq(dataType, BooleanType),
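In createCheckOverflowExprNode above, the old signature took an untyped java.lang.Object that had to be cast back to the function map before an id could be allocated; the new signature takes the typed SubstraitContext and asks it directly. A rough sketch of that particular simplification, reusing the stand-in SubstraitContext from the sketch above (the inlined lookup-or-insert only mirrors what the removed cast and ExpressionBuilder.newScalarFunction lines did at this call site, not Gluten's exact internals):

object CheckOverflowSketch {
  // Before: an untyped Object parameter, cast back to the function map at the call site.
  def checkOverflowFuncIdBefore(args: java.lang.Object, funcName: String): java.lang.Long = {
    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
    // The id allocation lived in ExpressionBuilder.newScalarFunction(functionMap, funcName);
    // here an equivalent lookup-or-insert is inlined for illustration only.
    if (!functionMap.containsKey(funcName)) {
      functionMap.put(funcName, java.lang.Long.valueOf(functionMap.size().toLong))
    }
    functionMap.get(funcName)
  }

  // After: a typed SubstraitContext parameter, no cast, id allocation behind registerFunction.
  def checkOverflowFuncIdAfter(context: SubstraitContext, funcName: String): java.lang.Long =
    context.registerFunction(funcName)
}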
@@ -85,7 +85,7 @@ class CHValidatorApi extends ValidatorApi with AdaptiveSparkPlanHelper with Logg
expr =>
val node = ExpressionConverter
.replaceWithExpressionTransformer(expr, outputAttributes)
.doTransform(substraitContext.registeredFunction)
.doTransform(substraitContext)
node.isInstanceOf[SelectionNode]
}
if (allSelectionNodes || supportShuffleWithProject(outputPartitioning, child)) {
@@ -86,13 +86,12 @@ case class CHAggregateGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.doTransform(context))
.asJava

// Sort By Expressions
@@ -102,7 +101,7 @@
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
.doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
@@ -238,7 +238,7 @@ case class CHHashAggregateExecTransformer(
operatorId: Long,
input: RelNode = null,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Get the grouping nodes.
val groupingList = new util.ArrayList[ExpressionNode]()
groupingExpressions.foreach(
@@ -247,7 +246,7 @@
// may be different for each backend.
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
.doTransform(context)
groupingList.add(exprNode)
})
// Get the aggregate function nodes.
@@ -267,7 +266,7 @@
if (aggExpr.filter.isDefined) {
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.filter.get, childOutput)
.doTransform(args)
.doTransform(context)
aggFilterList.add(exprNode)
} else {
aggFilterList.add(null)
@@ -281,7 +280,7 @@
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, childOutput)
.doTransform(args)
.doTransform(context)
})

val extraNodes = aggregateFunc match {
@@ -290,7 +289,7 @@
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(relativeSDLiteral, child.output)
.doTransform(args))
.doTransform(context))
case _ => Seq.empty
}

@@ -311,20 +310,20 @@
child.asInstanceOf[BaseAggregateExec].groupingExpressions,
child.asInstanceOf[BaseAggregateExec].aggregateExpressions)
)
Seq(aggTypesExpr.doTransform(args))
Seq(aggTypesExpr.doTransform(context))
case Final | PartialMerge =>
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(aggExpr.resultAttribute, originalInputAttributes)
.doTransform(args))
.doTransform(context))
case other =>
throw new GlutenNotSupportException(s"$other not supported.")
}
for (node <- childrenNodes) {
childrenNodeList.add(node)
}
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
CHExpressions.createAggregateFunction(args, aggregateFunc),
CHExpressions.createAggregateFunction(context, aggregateFunc),
childrenNodeList,
modeToKeyWord(aggExpr.mode),
ConverterUtils.getTypeNode(aggregateFunc.dataType, aggregateFunc.nullable)
@@ -96,13 +96,12 @@ case class CHWindowGroupLimitExecTransformer(
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.doTransform(context))
.asJava

// Sort By Expressions
@@ -112,7 +111,7 @@
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
.doTransform(context)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
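The transformer files (CHAggregateGroupLimitExecTransformer, CHHashAggregateExecTransformer, CHWindowGroupLimitExecTransformer) all show the same relation-level cleanup: the local val args = context.registeredFunction disappears and the context flows straight into each expression transformer's doTransform. A minimal sketch of that shape, with a hypothetical ExprTransformer standing in for Gluten's ExpressionTransformer and String standing in for the Substrait expression node type:

object TransformerSketch {
  // Hypothetical, simplified stand-in for Gluten's ExpressionTransformer trait.
  trait ExprTransformer {
    def doTransform(context: SubstraitContext): String // the real code returns an ExpressionNode
  }

  // Before (sketch):
  //   val args = context.registeredFunction
  //   partitionSpec.map(expr => transformerFor(expr).doTransform(args))
  //
  // After (sketch): no intermediate map, the context is threaded directly.
  def transformAll(
      context: SubstraitContext,
      transformers: Seq[ExprTransformer]): Seq[String] =
    transformers.map(_.doTransform(context))
}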