From 5448708648fe4b503903c315f4d15d28ccc3f5fe Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 6 May 2021 23:41:36 +0900 Subject: [PATCH 1/7] Fix --- .../spark/sql/execution/ExpandExec.scala | 60 +++++++++++++------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 3fd653130e57c..24ad90b929cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,15 +153,16 @@ case class ExpandExec( // This column is the same across all output rows. Just generate code for it here. BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx) } else { - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val code = code""" - |boolean $isNull = true; - |${CodeGenerator.javaType(firstExpr.dataType)} $value = - | ${CodeGenerator.defaultValue(firstExpr.dataType)}; - """.stripMargin + val isNull = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, + "resultIsNull", + v => s"$v = true;") + val value = ctx.addMutableState( + CodeGenerator.javaType(firstExpr.dataType), + "resultValue", + v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};") + ExprCode( - code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, firstExpr.dataType)) } @@ -168,22 +170,42 @@ case class ExpandExec( // Part 2: switch/case statements val cases = projections.zipWithIndex.map { case (exprs, row) => - var updateCode = "" - for (col <- exprs.indices) { + val updateCode = mutable.ArrayBuffer[String]() + exprs.indices.foreach { col => if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx) - updateCode += - s""" - |${ev.code} - |${outputColumns(col).isNull} = ${ev.isNull}; - |${outputColumns(col).value} = ${ev.value}; - """.stripMargin + val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) + val ev = boundExpr.genCode(ctx) + val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1.toSeq + val argList = inputVars.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + } + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) + if (CodeGenerator.isValidParamLength(paramLength)) { + val switchCaseFunc = ctx.freshName("switchCaseCode") + ctx.addNewFunction(switchCaseFunc, + s""" + |private void $switchCaseFunc(${argList.mkString(", ")}) { + | ${ev.code} + | ${outputColumns(col).isNull} = ${ev.isNull}; + | ${outputColumns(col).value} = ${ev.value}; + |} + """.stripMargin) + + updateCode += s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" + } else { + updateCode += + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + } } } s""" |case $row: - | ${updateCode.trim} + | ${updateCode.mkString("\n")} | break; """.stripMargin } From cb182b888439d3efe1e46aa0aa44fb1ede96ff8f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 7 May 2021 09:12:48 +0900 Subject: [PATCH 2/7] review --- .../scala/org/apache/spark/sql/execution/ExpandExec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 24ad90b929cdf..cac2774e3532d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -176,12 +176,12 @@ case class ExpandExec( val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) val ev = boundExpr.genCode(ctx) val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1.toSeq - val argList = inputVars.map { v => - s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" - } val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) if (CodeGenerator.isValidParamLength(paramLength)) { val switchCaseFunc = ctx.freshName("switchCaseCode") + val argList = inputVars.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + } ctx.addNewFunction(switchCaseFunc, s""" |private void $switchCaseFunc(${argList.mkString(", ")}) { From ba0945d9184c2d379a42f968b85c20f373062aea Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 7 May 2021 22:31:02 +0900 Subject: [PATCH 3/7] Review --- .../scala/org/apache/spark/sql/execution/ExpandExec.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index cac2774e3532d..2710a40654538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf /** * Apply all of the GroupExpressions to every input row, hence we will get @@ -176,8 +177,9 @@ case class ExpandExec( val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) val ev = boundExpr.genCode(ctx) val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1.toSeq + val splitThreshold = SQLConf.get.methodSplitThreshold val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) - if (CodeGenerator.isValidParamLength(paramLength)) { + if (ev.code.length > splitThreshold && CodeGenerator.isValidParamLength(paramLength)) { val switchCaseFunc = ctx.freshName("switchCaseCode") val argList = inputVars.map { v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" From 18e87d252da2dcc1d2cc5ab36163ad1c3139dea4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 7 May 2021 23:48:04 +0900 Subject: [PATCH 4/7] Review --- .../spark/sql/execution/ExpandExec.scala | 68 ++++++++++--------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 2710a40654538..9c427fe419ef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -171,43 +169,49 @@ case class ExpandExec( // Part 2: switch/case statements val cases = projections.zipWithIndex.map { case (exprs, row) => - val updateCode = mutable.ArrayBuffer[String]() - exprs.indices.foreach { col => + val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => if (!sameOutput(col)) { val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) - val ev = boundExpr.genCode(ctx) - val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1.toSeq - val splitThreshold = SQLConf.get.methodSplitThreshold - val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) - if (ev.code.length > splitThreshold && CodeGenerator.isValidParamLength(paramLength)) { - val switchCaseFunc = ctx.freshName("switchCaseCode") - val argList = inputVars.map { v => - s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" - } - ctx.addNewFunction(switchCaseFunc, - s""" - |private void $switchCaseFunc(${argList.mkString(", ")}) { - | ${ev.code} - | ${outputColumns(col).isNull} = ${ev.isNull}; - | ${outputColumns(col).value} = ${ev.value}; - |} - """.stripMargin) - - updateCode += s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" - } else { - updateCode += - s""" - |${ev.code} - |${outputColumns(col).isNull} = ${ev.isNull}; - |${outputColumns(col).value} = ${ev.value}; - """.stripMargin - } + val exprCode = boundExpr.genCode(ctx) + val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1 + Some(((col, exprCode), inputVars)) + } else { + None + } + }.unzip + + val updateCode = exprCodesWithIndices.map { case (col, ev) => + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + } + + val splitThreshold = SQLConf.get.methodSplitThreshold + val inputVars = inputVarSets.reduce(_ ++ _) + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars.toSeq) + val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength) && + exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) { + val switchCaseFunc = ctx.freshName("switchCaseCode") + val argList = inputVars.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" } + ctx.addNewFunction(switchCaseFunc, + s""" + |private void $switchCaseFunc(${argList.mkString(", ")}) { + | ${updateCode.mkString("\n")} + |} + """.stripMargin) + + s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" + } else { + updateCode.mkString("\n") } s""" |case $row: - | ${updateCode.mkString("\n")} + | $maybeSplitUpdateCode | break; """.stripMargin } From a1873d906dbc660ebd4d31853732f64c7c0231b7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 8 May 2021 08:41:49 +0900 Subject: [PATCH 5/7] Fix --- .../main/scala/org/apache/spark/sql/execution/ExpandExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 9c427fe419ef3..b319ab3c6dd98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -189,7 +189,7 @@ case class ExpandExec( } val splitThreshold = SQLConf.get.methodSplitThreshold - val inputVars = inputVarSets.reduce(_ ++ _) + val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _) val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars.toSeq) val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength) && exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) { From b01bb8ce8dec363541efd17c336b109c40706535 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 11 May 2021 09:20:50 +0900 Subject: [PATCH 6/7] review --- .../spark/sql/execution/ExpandExec.scala | 69 +++++++++++-------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index b319ab3c6dd98..a3466232a1086 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -168,7 +168,7 @@ case class ExpandExec( } // Part 2: switch/case statements - val cases = projections.zipWithIndex.map { case (exprs, row) => + val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) => val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => if (!sameOutput(col)) { val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) @@ -180,40 +180,55 @@ case class ExpandExec( } }.unzip - val updateCode = exprCodesWithIndices.map { case (col, ev) => + val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _) + (row, exprCodesWithIndices, inputVars.toSeq) + } + + def generateUpdateCode(exprCodes: Seq[(Int, ExprCode)]): String = { + exprCodes.map { case (col, ev) => s""" |${ev.code} |${outputColumns(col).isNull} = ${ev.isNull}; |${outputColumns(col).value} = ${ev.value}; """.stripMargin - } + }.mkString("\n") + } - val splitThreshold = SQLConf.get.methodSplitThreshold - val inputVars = inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _) - val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars.toSeq) - val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength) && - exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) { - val switchCaseFunc = ctx.freshName("switchCaseCode") - val argList = inputVars.map { v => - s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + val splitThreshold = SQLConf.get.methodSplitThreshold + val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) { + switchCaseExprs.map { case (row, exprCodes, inputVars) => + val updateCode = generateUpdateCode(exprCodes) + val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) + val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) { + val switchCaseFunc = ctx.freshName("switchCaseCode") + val argList = inputVars.map { v => + s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}" + } + ctx.addNewFunction(switchCaseFunc, + s""" + |private void $switchCaseFunc(${argList.mkString(", ")}) { + | $updateCode + |} + """.stripMargin) + + s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" + } else { + updateCode } - ctx.addNewFunction(switchCaseFunc, - s""" - |private void $switchCaseFunc(${argList.mkString(", ")}) { - | ${updateCode.mkString("\n")} - |} - """.stripMargin) - - s"$switchCaseFunc(${inputVars.map(_.variableName).mkString(", ")});" - } else { - updateCode.mkString("\n") + s""" + |case $row: + | $maybeSplitUpdateCode + | break; + """.stripMargin + } + } else { + switchCaseExprs.map { case (row, exprCodes, _) => + s""" + |case $row: + | ${generateUpdateCode(exprCodes)} + | break; + """.stripMargin } - - s""" - |case $row: - | $maybeSplitUpdateCode - | break; - """.stripMargin } val numOutput = metricTerm(ctx, "numOutputRows") From 9e74ca601529fa27cc1ac55aa015430fce0990a1 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 12 May 2021 12:55:19 +0900 Subject: [PATCH 7/7] review --- .../org/apache/spark/sql/execution/ExpandExec.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a3466232a1086..c087fdf5f962b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -184,7 +184,7 @@ case class ExpandExec( (row, exprCodesWithIndices, inputVars.toSeq) } - def generateUpdateCode(exprCodes: Seq[(Int, ExprCode)]): String = { + val updateCodes = switchCaseExprs.map { case (_, exprCodes, _) => exprCodes.map { case (col, ev) => s""" |${ev.code} @@ -196,8 +196,7 @@ case class ExpandExec( val splitThreshold = SQLConf.get.methodSplitThreshold val cases = if (switchCaseExprs.flatMap(_._2.map(_._2.code.length)).sum > splitThreshold) { - switchCaseExprs.map { case (row, exprCodes, inputVars) => - val updateCode = generateUpdateCode(exprCodes) + switchCaseExprs.zip(updateCodes).map { case ((row, _, inputVars), updateCode) => val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars) val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength)) { val switchCaseFunc = ctx.freshName("switchCaseCode") @@ -222,10 +221,10 @@ case class ExpandExec( """.stripMargin } } else { - switchCaseExprs.map { case (row, exprCodes, _) => + switchCaseExprs.map(_._1).zip(updateCodes).map { case (row, updateCode) => s""" |case $row: - | ${generateUpdateCode(exprCodes)} + | $updateCode | break; """.stripMargin }