diff --git a/jvm/src/main/scala/io/glutenproject/GlutenPlugin.scala b/jvm/src/main/scala/io/glutenproject/GlutenPlugin.scala index f1c8173778bf..0016d5e23d2d 100644 --- a/jvm/src/main/scala/io/glutenproject/GlutenPlugin.scala +++ b/jvm/src/main/scala/io/glutenproject/GlutenPlugin.scala @@ -24,7 +24,7 @@ import scala.language.implicitConversions import com.google.protobuf.Any import io.glutenproject.GlutenPlugin.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY} import io.glutenproject.backendsapi.BackendsApiManager -import io.glutenproject.extension.{ColumnarOverrides, OthersExtensionOverrides, StrategyOverrides} +import io.glutenproject.extension.{ColumnarOverrides, ColumnarQueryStagePreparations, OthersExtensionOverrides, StrategyOverrides} import io.glutenproject.substrait.expression.ExpressionBuilder import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.plan.PlanBuilder @@ -118,6 +118,7 @@ private[glutenproject] object GlutenPlugin { * Specify all injectors that Gluten is using in following list. 
*/ val DEFAULT_INJECTORS: List[GlutenSparkExtensionsInjector] = List( + ColumnarQueryStagePreparations, ColumnarOverrides, StrategyOverrides, OthersExtensionOverrides diff --git a/jvm/src/main/scala/io/glutenproject/execution/ShuffledHashJoinExecTransformer.scala b/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala similarity index 99% rename from jvm/src/main/scala/io/glutenproject/execution/ShuffledHashJoinExecTransformer.scala rename to jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala index d376436f18b1..da4860d2a798 100644 --- a/jvm/src/main/scala/io/glutenproject/execution/ShuffledHashJoinExecTransformer.scala +++ b/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala @@ -167,8 +167,7 @@ abstract class HashJoinLikeExecTransformer( if (GlutenConfig.getConf.enableNativeValidation) { val validator = new ExpressionEvaluator() val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(relNode)) - val result = validator.doValidate(planNode.toProtobuf.toByteArray) - result + validator.doValidate(planNode.toProtobuf.toByteArray) } else { true } diff --git a/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index a8b23ea72155..25ffc5874627 100644 --- a/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ b/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -46,10 +46,20 @@ case class TransformPreOverrides() extends Rule[SparkPlan] { def replaceWithTransformerPlan(plan: SparkPlan): SparkPlan = plan match { case RowGuard(child: CustomShuffleReaderExec) => replaceWithTransformerPlan(child) + case RowGuard(bhj: BroadcastHashJoinExec) => + bhj.withNewChildren(bhj.children.map { + // ReusedExchange is not created yet, so we don't need to handle that case. 
+ case e: BroadcastExchangeExec => + replaceWithTransformerPlan(RowGuard(e)) + case other => replaceWithTransformerPlan(other) + }) case plan: RowGuard => val actualPlan = plan.child logDebug(s"Columnar Processing for ${actualPlan.getClass} is under RowGuard.") actualPlan.withNewChildren(actualPlan.children.map(replaceWithTransformerPlan)) + case plan if plan.getTagValue(RowGuardTag.key).contains(true) => + // Add RowGuard if the plan has a RowGuardTag. + replaceWithTransformerPlan(RowGuard(plan)) /* case plan: ArrowEvalPythonExec => val columnarChild = replaceWithTransformerPlan(plan.child) ArrowEvalPythonExecTransformer(plan.udfs, plan.resultAttrs, columnarChild, plan.evalType) */ @@ -281,7 +291,6 @@ case class TransformPostOverrides() extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { replaceWithTransformerPlan(plan) } - } case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { diff --git a/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala b/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala new file mode 100644 index 000000000000..81478b2f5d51 --- /dev/null +++ b/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.glutenproject.extension + +import io.glutenproject.{GlutenConfig, GlutenSparkExtensionsInjector} +import io.glutenproject.execution.BroadcastHashJoinExecTransformer + +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec + +// RowGuardTag is useful to transform the plan and add guard tag before creating new QueryStages. +// +// e.g. BroadcastHashJoinExec and its child BroadcastExec will be cut into different QueryStages, +// so the columnar rules will be applied to the two QueryStages separately, and they cannot +// see each other during transformation. In order to prevent BroadcastExec being transformed +// to columnar while BHJ falls back, we can add RowGuardTag to BroadcastExec when applying +// queryStagePrepRules and check the tag when applying columnarRules. +// RowGuardTag will be ignored if the plan is already guarded by RowGuard. 
+object RowGuardTag {
+  val key: TreeNodeTag[Boolean] = TreeNodeTag[Boolean]("RowGuard")
+  val value: Boolean = true
+}
+
+case class ColumnarQueryStagePrepRule(session: SparkSession) extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val columnarConf: GlutenConfig = GlutenConfig.getSessionConf
+    plan.transformDown {
+      case bhj: BroadcastHashJoinExec =>
+        if (columnarConf.enableColumnarBroadcastExchange &&
+          columnarConf.enableColumnarBroadcastJoin) {
+          val transformer = BroadcastHashJoinExecTransformer(
+            bhj.leftKeys,
+            bhj.rightKeys,
+            bhj.joinType,
+            bhj.buildSide,
+            bhj.condition,
+            bhj.left,
+            bhj.right,
+            bhj.isNullAwareAntiJoin)
+          if (!transformer.doValidate()) {
+            bhj.children.foreach {
+              // ReusedExchange is not created yet, so only BroadcastExchangeExec needs the tag.
+              case e: BroadcastExchangeExec => AddRowGuardTag(e)
+              case _ =>
+            }
+          }
+        }
+        bhj
+      case plan => plan
+    }
+  }
+
+  def AddRowGuardTag(plan: SparkPlan): SparkPlan = {
+    plan.setTagValue(RowGuardTag.key, RowGuardTag.value)
+    plan
+  }
+}
+
+object ColumnarQueryStagePreparations extends GlutenSparkExtensionsInjector {
+  override def inject(extensions: SparkSessionExtensions): Unit = {
+    extensions.injectQueryStagePrepRule(ColumnarQueryStagePrepRule)
+  }
+}
+
diff --git a/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala b/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala
index 84fd92c00733..5968ad613b10 100644
--- a/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala
+++ b/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala
@@ -66,8 +66,10 @@ case class TransformGuardRule() extends Rule[SparkPlan] {
   val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion
   val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand
   val enableColumnarShuffledHashJoin: Boolean = columnarConf.enableColumnarShuffledHashJoin
-  val
enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange - val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastExchange: Boolean = + columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange + val enableColumnarBroadcastJoin: Boolean = + columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange val enableColumnarArrowUDF: Boolean = columnarConf.enableColumnarArrowUDF def apply(plan: SparkPlan): SparkPlan = { @@ -154,7 +156,7 @@ case class TransformGuardRule() extends Rule[SparkPlan] { transformer.doValidate() case plan: BroadcastExchangeExec => // columnar broadcast is enabled only when columnar bhj is enabled. - if (!enableColumnarBroadcastJoin) return false + if (!enableColumnarBroadcastExchange) return false val exec = ColumnarBroadcastExchangeExec(plan.mode, plan.child) exec.doValidate() case plan: BroadcastHashJoinExec =>