diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 3fd653130e57c..c087fdf5f962b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,9 +21,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf /** * Apply all of the GroupExpressions to every input row, hence we will get @@ -152,40 +152,82 @@ case class ExpandExec( // This column is the same across all output rows. Just generate code for it here. BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx) } else { - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val code = code""" - |boolean $isNull = true; - |${CodeGenerator.javaType(firstExpr.dataType)} $value = - | ${CodeGenerator.defaultValue(firstExpr.dataType)}; - """.stripMargin + val isNull = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, + "resultIsNull", + v => s"$v = true;") + val value = ctx.addMutableState( + CodeGenerator.javaType(firstExpr.dataType), + "resultValue", + v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};") + ExprCode( - code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, firstExpr.dataType)) } } // Part 2: switch/case statements - val cases = projections.zipWithIndex.map { case (exprs, row) => - var updateCode = "" - for (col <- exprs.indices) { + val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) => + val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx) - updateCode += + val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) + val exprCode = boundExpr.genCode(ctx) + val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1 + Some(((col, exprCode), inputVars)) + } else { + None + } + }.unzip + + val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _) + (row, exprCodesWithIndices, inputVars.toSeq) + } + + val updateCodes = switchCaseExprs.map { case (_, exprCodes, _) => + exprCodes.map { case (col, ev) => + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + }.mkString("\n") + } + + val splitThreshold = SQLConf.get.methodSplitThreshold + val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) { + switchCaseExprs.zip(updateCodes).map { case ((row, _, inputVars), updateCode) => + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) + val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) { + val switchCaseFunc = ctx.freshName("switchCaseCode") + val argList = inputVars.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + } + ctx.addNewFunction(switchCaseFunc, s""" - |${ev.code} - |${outputColumns(col).isNull} = ${ev.isNull}; - |${outputColumns(col).value} = ${ev.value}; - """.stripMargin + |private void $switchCaseFunc(${argList.mkString(", ")}) { + | $updateCode + |} + """.stripMargin) + + s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" + } else { + updateCode } + s""" + |case $row: + | $maybeSplitUpdateCode + | break; + """.stripMargin + } + } else { + switchCaseExprs.map(_._1).zip(updateCodes).map { case (row, updateCode) => + s""" + |case $row: + | $updateCode + | break; + """.stripMargin } - - s""" - |case $row: - | ${updateCode.trim} - | break; - """.stripMargin } val numOutput = metricTerm(ctx, "numOutputRows")