bhj fallback (facebookincubator#263)
marin-ma authored Jul 15, 2022
1 parent 8433b68 commit 2c67cad
Showing 5 changed files with 101 additions and 7 deletions.
jvm/src/main/scala/io/glutenproject/GlutenPlugin.scala (2 additions, 1 deletion)
@@ -24,7 +24,7 @@ import scala.language.implicitConversions
import com.google.protobuf.Any
import io.glutenproject.GlutenPlugin.{GLUTEN_SESSION_EXTENSION_NAME, SPARK_SESSION_EXTS_KEY}
import io.glutenproject.backendsapi.BackendsApiManager
-import io.glutenproject.extension.{ColumnarOverrides, OthersExtensionOverrides, StrategyOverrides}
+import io.glutenproject.extension.{ColumnarOverrides, ColumnarQueryStagePreparations, OthersExtensionOverrides, StrategyOverrides}
import io.glutenproject.substrait.expression.ExpressionBuilder
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.plan.PlanBuilder
@@ -118,6 +118,7 @@ private[glutenproject] object GlutenPlugin {
* Specify all injectors that Gluten is using in the following list.
*/
val DEFAULT_INJECTORS: List[GlutenSparkExtensionsInjector] = List(
ColumnarQueryStagePreparations,
ColumnarOverrides,
StrategyOverrides,
OthersExtensionOverrides
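For orientation: the only change to GlutenPlugin.scala is registering the new ColumnarQueryStagePreparations injector in DEFAULT_INJECTORS. A minimal sketch of how such a list of injectors is applied, assuming only the trait signature used in this patch (the InjectAll helper below is hypothetical, not part of Gluten):

    import org.apache.spark.sql.SparkSessionExtensions

    // Trait signature as used by the injectors registered above.
    trait GlutenSparkExtensionsInjector {
      def inject(extensions: SparkSessionExtensions): Unit
    }

    // Hypothetical helper: every injector in the list registers its rules on the
    // same SparkSessionExtensions instance when the Gluten session extension loads.
    object InjectAll {
      def apply(injectors: List[GlutenSparkExtensionsInjector],
                extensions: SparkSessionExtensions): Unit =
        injectors.foreach(_.inject(extensions))
    }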
@@ -167,8 +167,7 @@ abstract class HashJoinLikeExecTransformer(
if (GlutenConfig.getConf.enableNativeValidation) {
val validator = new ExpressionEvaluator()
val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(relNode))
-val result = validator.doValidate(planNode.toProtobuf.toByteArray)
-result
+validator.doValidate(planNode.toProtobuf.toByteArray)
} else {
true
}
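The hunk above is a small cleanup: the native validation result is returned directly instead of going through a temporary val. The surrounding pattern, serializing the Substrait plan to protobuf bytes and asking the native side whether it can execute it, can be sketched roughly as follows (NativeValidator and shouldOffload are illustrative stand-ins, not the actual ExpressionEvaluator API):

    // Illustrative stand-in for the native validator used above.
    trait NativeValidator {
      def doValidate(planBytes: Array[Byte]): Boolean
    }

    // Offload only when native validation is enabled and the serialized plan passes it;
    // when validation is disabled, the plan is assumed to be supported.
    def shouldOffload(enableNativeValidation: Boolean,
                      planBytes: Array[Byte],
                      validator: NativeValidator): Boolean =
      if (enableNativeValidation) validator.doValidate(planBytes) else true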
@@ -46,10 +46,20 @@ case class TransformPreOverrides() extends Rule[SparkPlan] {
def replaceWithTransformerPlan(plan: SparkPlan): SparkPlan = plan match {
case RowGuard(child: CustomShuffleReaderExec) =>
replaceWithTransformerPlan(child)
case RowGuard(bhj: BroadcastHashJoinExec) =>
bhj.withNewChildren(bhj.children.map {
// ReusedExchange is not created yet, so we don't need to handle that case.
case e: BroadcastExchangeExec =>
replaceWithTransformerPlan(RowGuard(e))
case other => replaceWithTransformerPlan(other)
})
case plan: RowGuard =>
val actualPlan = plan.child
logDebug(s"Columnar Processing for ${actualPlan.getClass} is under RowGuard.")
actualPlan.withNewChildren(actualPlan.children.map(replaceWithTransformerPlan))
case plan if plan.getTagValue(RowGuardTag.key).contains(true) =>
// Add RowGuard if the plan has a RowGuardTag.
replaceWithTransformerPlan(RowGuard(plan))
/* case plan: ArrowEvalPythonExec =>
val columnarChild = replaceWithTransformerPlan(plan.child)
ArrowEvalPythonExecTransformer(plan.udfs, plan.resultAttrs, columnarChild, plan.evalType) */
@@ -281,7 +291,6 @@ case class TransformPostOverrides() extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
replaceWithTransformerPlan(plan)
}

}

case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging {
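The two new cases above are the consumer side of the fallback: a BroadcastHashJoinExec wrapped in RowGuard also row-guards its BroadcastExchangeExec child, and any node carrying RowGuardTag (set during query stage preparation, see the new file below) is wrapped in RowGuard before transformation. The tag round-trip itself relies only on Spark's public TreeNodeTag API, as in this small self-contained sketch (a Literal stands in for the BroadcastExchangeExec node):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.TreeNodeTag

    object TagRoundTripExample extends App {
      val guard = TreeNodeTag[Boolean]("RowGuard")
      val node = Literal(1)            // stand-in for a BroadcastExchangeExec node
      node.setTagValue(guard, true)    // done in the query stage preparation rule
      // later, TransformPreOverrides checks the same tag on the same node
      println(node.getTagValue(guard).contains(true))   // prints: true
    }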
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.extension

import io.glutenproject.{GlutenConfig, GlutenSparkExtensionsInjector}
import io.glutenproject.execution.BroadcastHashJoinExecTransformer

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

// RowGuardTag marks plan nodes that must stay row-based, and is added before new QueryStages
// are created.
//
// e.g. BroadcastHashJoinExec and its child BroadcastExchangeExec will be cut into different
// QueryStages, so the columnar rules are applied to the two QueryStages separately and they
// cannot see each other during transformation. To prevent the BroadcastExchangeExec from being
// transformed to columnar while the BHJ falls back, we add RowGuardTag to the
// BroadcastExchangeExec when applying queryStagePrepRules and check the tag when applying
// columnarRules. RowGuardTag is ignored if the plan is already guarded by RowGuard.
object RowGuardTag {
val key: TreeNodeTag[Boolean] = TreeNodeTag[Boolean]("RowGuard")
val value: Boolean = true
}

case class ColumnarQueryStagePrepRule(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
val columnarConf: GlutenConfig = GlutenConfig.getSessionConf
plan.transformDown {
case bhj: BroadcastHashJoinExec =>
if (columnarConf.enableColumnarBroadcastExchange &&
columnarConf.enableColumnarBroadcastJoin) {
val transformer = BroadcastHashJoinExecTransformer(
bhj.leftKeys,
bhj.rightKeys,
bhj.joinType,
bhj.buildSide,
bhj.condition,
bhj.left,
bhj.right,
bhj.isNullAwareAntiJoin)
if (!transformer.doValidate()) {
bhj.children.map {
// ReusedExchange is not created yet, so we don't need to handle that case.
case e: BroadcastExchangeExec => AddRowGuardTag(e)
case plan => plan
}
}
}
bhj
case plan => plan
}
}

def AddRowGuardTag(plan: SparkPlan): SparkPlan = {
plan.setTagValue(RowGuardTag.key, RowGuardTag.value)
plan
}
}

object ColumnarQueryStagePreparations extends GlutenSparkExtensionsInjector {
override def inject(extensions: SparkSessionExtensions): Unit = {
extensions.injectQueryStagePrepRule(ColumnarQueryStagePrepRule)
}
}
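This rule is registered through SparkSessionExtensions.injectQueryStagePrepRule (see the injector object directly above), so it runs during adaptive query execution before the plan is cut into query stages, early enough that the tag it sets is still visible when the columnar rules later process each stage. A hedged sketch of wiring a query stage prep rule through the same public Spark API outside of Gluten (the extension class and the no-op rule are placeholders; "spark.sql.extensions" is the standard Spark config key):

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Placeholder extension that registers a no-op query stage preparation rule.
    class MyQueryStagePrepExtension extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectQueryStagePrepRule { _: SparkSession =>
          new Rule[SparkPlan] {
            override def apply(plan: SparkPlan): SparkPlan = plan
          }
        }
      }
    }

    // Usage:
    //   SparkSession.builder()
    //     .config("spark.sql.extensions", classOf[MyQueryStagePrepExtension].getName)
    //     .getOrCreate()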

@@ -66,8 +66,10 @@ case class TransformGuardRule() extends Rule[SparkPlan] {
val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion
val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand
val enableColumnarShuffledHashJoin: Boolean = columnarConf.enableColumnarShuffledHashJoin
-val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange
-val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin
+val enableColumnarBroadcastExchange: Boolean =
+  columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange
+val enableColumnarBroadcastJoin: Boolean =
+  columnarConf.enableColumnarBroadcastJoin && columnarConf.enableColumnarBroadcastExchange
val enableColumnarArrowUDF: Boolean = columnarConf.enableColumnarArrowUDF

def apply(plan: SparkPlan): SparkPlan = {
@@ -154,7 +156,7 @@ case class TransformGuardRule() extends Rule[SparkPlan] {
transformer.doValidate()
case plan: BroadcastExchangeExec =>
// columnar broadcast is enabled only when columnar bhj is enabled.
-if (!enableColumnarBroadcastJoin) return false
+if (!enableColumnarBroadcastExchange) return false
val exec = ColumnarBroadcastExchangeExec(plan.mode, plan.child)
exec.doValidate()
case plan: BroadcastHashJoinExec =>
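The net effect of the last file is that the two switches are coupled: columnar broadcast exchange and columnar broadcast hash join are only treated as enabled when both are enabled by the user, so the pair falls back together rather than leaving a columnar broadcast exchange feeding a row-based join (or vice versa). A minimal sketch of that coupling, with illustrative flag names rather than Gluten's actual configuration entries:

    // Both effective flags are the conjunction of the user settings, so the broadcast
    // hash join and its broadcast exchange either both go columnar or both stay row-based.
    final case class BroadcastFlags(userBroadcastJoin: Boolean, userBroadcastExchange: Boolean) {
      val enableColumnarBroadcastJoin: Boolean = userBroadcastJoin && userBroadcastExchange
      val enableColumnarBroadcastExchange: Boolean = userBroadcastJoin && userBroadcastExchange
    }

    // Example: enabling only the join still disables both.
    //   BroadcastFlags(userBroadcastJoin = true, userBroadcastExchange = false)
    //     .enableColumnarBroadcastJoin  // false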
