DPP should be considered as a selective predicate
beliefer committed Apr 28, 2024
1 parent eba6364 commit a40d6cc
Showing 15 changed files with 464 additions and 69 deletions.
@@ -298,6 +298,49 @@ trait PredicateHelper extends AliasHelper with Logging {
    case _: MultiLikeBase => true
    case _ => false
  }

  /**
   * Get the distinct count of an attribute for a given plan.
   */
  def distinctCounts(attr: Attribute, plan: LogicalPlan): Option[BigInt] = {
    plan.stats.attributeStats.get(attr).flatMap(_.distinctCount)
  }

  /**
   * We estimate the filtering ratio using column statistics if they are available, otherwise we
   * use the config value of `spark.sql.optimizer.dynamicPartitionPruning.fallbackFilterRatio`.
   */
  def estimateFilteringRatio(
      partExpr: Expression,
      partPlan: LogicalPlan,
      otherExpr: Expression,
      otherPlan: LogicalPlan,
      conf: SQLConf): Double = {
    // The default filtering ratio when CBO stats are missing, but there is a
    // predicate that is likely to be selective.
    val fallbackRatio = conf.dynamicPartitionPruningFallbackFilterRatio
    // The filtering ratio based on the type of the join condition and on the column statistics.
    (partExpr.references.toList, otherExpr.references.toList) match {
      // Filter out expressions with more than one attribute on any side of the operator.
      case (leftAttr :: Nil, rightAttr :: Nil) if conf.dynamicPartitionPruningUseStats =>
        // Get the CBO stats for each attribute in the join condition.
        val partDistinctCount = distinctCounts(leftAttr, partPlan)
        val otherDistinctCount = distinctCounts(rightAttr, otherPlan)
        val availableStats = partDistinctCount.isDefined && partDistinctCount.get > 0 &&
          otherDistinctCount.isDefined
        if (!availableStats) {
          fallbackRatio
        } else if (partDistinctCount.get.toDouble <= otherDistinctCount.get.toDouble) {
          // There is likely an estimation error, so we fall back.
          fallbackRatio
        } else {
          1 - otherDistinctCount.get.toDouble / partDistinctCount.get.toDouble
        }
      case _ => fallbackRatio
    }
  }
}

@ExpressionDescription(
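The estimate above treats a DPP filter as selective only when the build side has far fewer distinct keys than the partitioned side. A minimal standalone sketch of that arithmetic, in plain Scala with hypothetical distinct counts (not part of this commit):

// Sketch of the stats branch of estimateFilteringRatio, with hypothetical counts.
def ratio(partDistinct: BigInt, otherDistinct: BigInt, fallback: Double): Double =
  if (partDistinct <= 0 || partDistinct.toDouble <= otherDistinct.toDouble) fallback
  else 1 - otherDistinct.toDouble / partDistinct.toDouble

// 1000 distinct partition keys joined against 10 build-side keys: ~99% pruned.
assert(math.abs(ratio(BigInt(1000), BigInt(10), fallback = 0.5) - 0.99) < 1e-9)
// Build side at least as wide as the partitioned side: fall back to the config value.
assert(ratio(BigInt(100), BigInt(100), fallback = 0.5) == 0.5)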
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_FILTER_EXPRESSION, RUNTIME_FILTER_SUBQUERY, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{BinaryType, DataType}

/**
 * The RuntimeFilterSubquery expression is used only for runtime filtering. It is inserted
 * in cases where the broadcast exchange can be reused.
 *
 * @param filterApplicationSideKey the filtering key of the application side.
 * @param filterCreationSidePlan the build side of the join.
 * @param filterCreationSideKey the key of the creation side.
 */
case class RuntimeFilterSubquery(
    filterApplicationSideKey: Expression,
    filterCreationSidePlan: LogicalPlan,
    filterCreationSideKey: Expression,
    exprId: ExprId = NamedExpression.newExprId,
    hint: Option[HintInfo] = None)
  extends SubqueryExpression(
    filterCreationSidePlan, Seq(filterApplicationSideKey), exprId, Seq.empty, hint)
  with Unevaluable
  with UnaryLike[Expression] {

  override def child: Expression = filterApplicationSideKey

  override def dataType: DataType = BinaryType

  override def plan: LogicalPlan = filterCreationSidePlan

  override def nullable: Boolean = false

  override def withNewPlan(plan: LogicalPlan): RuntimeFilterSubquery =
    copy(filterCreationSidePlan = plan)

  override def withNewOuterAttrs(outerAttrs: Seq[Expression]): RuntimeFilterSubquery = {
    // Updating outer attrs of RuntimeFilterSubquery is unsupported; assert that they match
    // filterApplicationSideKey and return a copy without any changes.
    assert(outerAttrs.size == 1 && outerAttrs.head.semanticEquals(filterApplicationSideKey))
    copy()
  }

  override def withNewHint(hint: Option[HintInfo]): SubqueryExpression = copy(hint = hint)

  override lazy val resolved: Boolean =
    filterApplicationSideKey.resolved &&
      filterCreationSidePlan.resolved &&
      filterCreationSideKey.resolved

  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_FILTER_SUBQUERY)

  override def toString: String = s"runtimefilter#${exprId.id} $conditionString"

  override lazy val canonicalized: RuntimeFilterSubquery = {
    copy(
      filterApplicationSideKey = filterApplicationSideKey.canonicalized,
      filterCreationSidePlan = filterCreationSidePlan.canonicalized,
      filterCreationSideKey = filterCreationSideKey.canonicalized,
      exprId = ExprId(0))
  }

  override protected def withNewChildInternal(newChild: Expression): RuntimeFilterSubquery =
    copy(filterApplicationSideKey = newChild)
}

/**
 * Marker for a planned runtime filter expression.
 * The expression is created during planning, and it defers to its child for evaluation.
 *
 * @param child the underlying aggregate for the runtime filter.
 */
case class RuntimeFilterExpression(child: Expression) extends UnaryExpression {

  override def dataType: DataType = child.dataType

  override def eval(input: InternalRow): Any = child.eval(input)

  final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_FILTER_EXPRESSION)

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    child.genCode(ctx)
  }

  override protected def withNewChildInternal(newChild: Expression): RuntimeFilterExpression =
    copy(child = newChild)
}
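As eval and doGenCode above suggest, RuntimeFilterExpression is evaluation-transparent. A quick illustrative sketch of that property, assuming the catalyst classes above are on the classpath (not part of the commit):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, RuntimeFilterExpression}

// Wrapping a child changes neither its type nor its evaluated value.
val wrapped = RuntimeFilterExpression(Literal(42))
assert(wrapped.dataType == Literal(42).dataType)
assert(wrapped.eval(InternalRow.empty) == 42)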
@@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
@@ -49,14 +47,45 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
)
}

private def hasSelectiveDynamicPruningSubquery(plan: LogicalPlan): Boolean = {
  plan.find {
    case Filter(condition, partPlan) =>
      val ratios = splitConjunctivePredicates(condition).collect {
        case DynamicPruningSubquery(pruningKey, buildPlan, buildKeys, indices, _, _, _) =>
          require(indices.size == 1, "DPP Filters should only have a single broadcasting key " +
            "since there is no usage for multiple broadcasting keys at the moment.")
          val buildKey = buildKeys(indices.head)
          val filterRatio =
            estimateFilteringRatio(pruningKey, partPlan, buildKey, buildPlan, conf)
          1 - filterRatio
      }

      if (ratios.isEmpty) {
        false
      } else {
        val finalRatio = ratios.reduce(_ * _)
        finalRatio * partPlan.stats.sizeInBytes.toDouble <=
          conf.runtimeFilterCreationSideThreshold
      }
    case _ => false
  }.isDefined
}
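// Worked example with hypothetical numbers: a single DPP subquery whose pruning key has
// 1000 distinct values on the partitioned side and 10 on the build side gives
// estimateFilteringRatio = 1 - 10.0 / 1000.0 = 0.99, so the collected ratio is
// 1 - 0.99 = 0.01. For a 10 GB creation side, finalRatio * sizeInBytes is ~100 MB,
// which is then compared against runtimeFilterCreationSideThreshold.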

private def injectBloomFilter(
    filterApplicationSideKey: Expression,
    filterApplicationSidePlan: LogicalPlan,
    filterCreationSideKey: Expression,
    filterCreationSidePlan: LogicalPlan): LogicalPlan = {
  filterCreationSidePlan match {
    case ProjectAdapter(_, child) =>
      if (!hasSelectiveDynamicPruningSubquery(child)) {
        return filterApplicationSidePlan
      }
    case _ =>
      // Skip if the filter creation side is too big.
      if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) {
        return filterApplicationSidePlan
      }
  }
  val rowCount = filterCreationSidePlan.stats.rowCount
  val bloomFilterAgg =
@@ -69,7 +98,13 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
  val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
  val aggregate =
    ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
  val bloomFilterSubquery = filterCreationSidePlan match {
    case _: ProjectAdapter =>
      // Try to reuse the results of the exchange.
      RuntimeFilterSubquery(filterApplicationSideKey, aggregate, filterCreationSideKey)
    case _ =>
      ScalarSubquery(aggregate, Nil)
  }
  val filter = BloomFilterMightContain(bloomFilterSubquery,
    new XxHash64(Seq(filterApplicationSideKey)))
  Filter(filter, filterApplicationSidePlan)
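  // Resulting shape (illustrative): the application side becomes
  //   Filter(BloomFilterMightContain(<subquery>, XxHash64(filterApplicationSideKey)), plan)
  // where <subquery> aggregates a bloom filter over the creation side, and is a reusable
  // RuntimeFilterSubquery only when the creation side was marked with ProjectAdapter.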
@@ -113,13 +148,23 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
      extract(child, predicateReference, hasHitFilter, hasHitSelectiveFilter, currentPlan,
        targetKey)
    case Filter(condition, child) if isSimpleExpression(condition) =>
      val (dynamicPrunings, otherPredicates) =
        splitConjunctivePredicates(condition).partition(_.isInstanceOf[DynamicPruningSubquery])

      val existsLikelySelective = otherPredicates.exists(isLikelySelective)
      val extracted = extract(
        child,
        predicateReference ++ condition.references,
        hasHitFilter = true,
        hasHitSelectiveFilter = hasHitSelectiveFilter || existsLikelySelective ||
          dynamicPrunings.nonEmpty,
        currentPlan,
        targetKey)
      if (conf.exchangeReuseEnabled && !existsLikelySelective && dynamicPrunings.nonEmpty) {
        extracted.map(kv => (kv._1, ProjectAdapter(kv._2.output, kv._2)))
      } else {
        extracted
      }
    case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
      // Runtime filters use one side of the [[Join]] to build a set of join key values and prune
      // the other side of the [[Join]]. It's also OK to use a superset of the join key values
@@ -231,20 +276,30 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
}

// This checks if there is already a DPP filter, as this rule is called just after DPP.
private def hasDynamicPruningSubquery(
    left: LogicalPlan,
    right: LogicalPlan,
    leftKey: Expression,
    rightKey: Expression): Boolean = {
  left.find {
    case Filter(condition, plan) =>
      splitConjunctivePredicates(condition).exists {
        case DynamicPruningSubquery(pruningKey, _, _, _, _, _, _) =>
          pruningKey.fastEquals(leftKey) ||
            hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
        case _ => false
      }
    case _ => false
  }.isDefined || right.find {
    case Filter(condition, plan) =>
      splitConjunctivePredicates(condition).exists {
        case DynamicPruningSubquery(pruningKey, _, _, _, _, _, _) =>
          pruningKey.fastEquals(rightKey) ||
            hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
        case _ => false
      }
    case _ => false
  }.isDefined
}

private def hasBloomFilter(plan: LogicalPlan, key: Expression): Boolean = {
@@ -69,6 +69,14 @@ object Subquery {
Subquery(s.plan, SubqueryExpression.hasCorrelatedSubquery(s))
}

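/**
 * A pass-through marker around the runtime filter creation side. Judging from its use in this
 * commit, InjectRuntimeFilter creates it with `projectList == child.output` so that later
 * planning can recognize the creation side and try to reuse the broadcast exchange.
 */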
case class ProjectAdapter(projectList: Seq[NamedExpression], child: LogicalPlan)
  extends OrderPreservingUnaryNode {
  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  override protected def withNewChildInternal(newChild: LogicalPlan): ProjectAdapter =
    copy(child = newChild)
}

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
  extends OrderPreservingUnaryNode {
  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -85,10 +85,13 @@ object TreePattern extends Enumeration {
val RUNTIME_REPLACEABLE: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val RUNTIME_FILTER_EXPRESSION: Value = Value
val RUNTIME_FILTER_SUBQUERY: Value = Value
val SCALA_UDF: Value = Value
val SESSION_WINDOW: Value = Value
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
val SUBQUERY_WRAPPER: Value = Value
val SUM: Value = Value
val TIME_WINDOW: Value = Value
val TIME_ZONE_AWARE_EXPRESSION: Value = Value
@@ -55,12 +55,14 @@ class SparkOptimizer(
// twice which may break some optimizer rules that can only be applied once. The rule below
// only invokes `OptimizeSubqueries` to optimize newly added subqueries.
new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates) :+
Batch("InjectRuntimeFilter", FixedPoint(1),
InjectRuntimeFilter) :+
Batch("MergeScalarSubqueries", Once,
MergeScalarSubqueries,
RewriteDistinctAggregates) :+
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
Batch("Pushdown Filters from RuntimeFilter", fixedPoint,
PushDownPredicates) :+
Batch("Cleanup filters that cannot be pushed down", Once,
CleanupDynamicPruningFilters,
@@ -954,6 +954,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Sort(sortExprs, global, child) =>
execution.SortExec(sortExprs, global, planLater(child)) :: Nil
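// ProjectAdapter is created with projectList == child.output, so it appears to plan as a
// simple pass-through InputAdapter here rather than a ProjectExec.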
case logical.ProjectAdapter(projectList, child) =>
execution.InputAdapter(planLater(child)) :: Nil
case logical.Project(projectList, child) =>
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
@@ -138,6 +138,7 @@ case class AdaptiveSparkPlanExec(
// optimizations should be stage-independent.
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
PlanAdaptiveDynamicPruningFilters(this),
PlanAdaptiveRuntimeFilterFilters(this),
ReuseAdaptiveSubquery(context.subqueryCache),
OptimizeSkewInRebalancePartitions,
CoalesceShufflePartitions(context.session),
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.LogKey.{CONFIG, SUB_QUERY}
import org.apache.spark.internal.MDC
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningSubquery, ListQuery, RuntimeFilterSubquery, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
@@ -130,11 +130,13 @@ case class InsertAdaptiveSparkPlan(
 */
private def buildSubqueryMap(plan: SparkPlan): Map[Long, SparkPlan] = {
  val subqueryMap = mutable.HashMap.empty[Long, SparkPlan]
  if (!plan.containsAnyPattern(
      SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY, RUNTIME_FILTER_SUBQUERY)) {
    return subqueryMap.toMap
  }
  plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach {
    case e @ (_: expressions.ScalarSubquery | _: ListQuery |
        _: DynamicPruningSubquery | _: RuntimeFilterSubquery) =>
      val subquery = e.asInstanceOf[SubqueryExpression]
      if (!subqueryMap.contains(subquery.exprId.id)) {
        val executedPlan = compileSubquery(subquery.plan)
