[SPARK-2205] [SQL] Avoid unnecessary exchange operators in multi-way joins #7773

Closed
@@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
* [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
* the ordering expressions are contiguous and will never be split across partitions.
* [[ClusteredDistribution]] as an ordering will ensure that tuples that share the
* same value for the ordering expressions are contiguous and will never be split across
* partitions.
*/
case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
require(
@@ -86,8 +87,12 @@ sealed trait Partitioning {
*/
def satisfies(required: Distribution): Boolean

/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
/**
Contributor:

Seems keyExpressions is not used at all and I do not remember when we added it. So, I am removing it.

* Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
Contributor Author:

One clarification: how does this property differ from compatibleWith? I just want to better understand the precise relationships between satisfies, compatibleWith, and guarantees.

Contributor Author:

This confusion has been significantly lessened by the removal of compatibleWith in my other patch.

* guarantees the same partitioning scheme described by `other`.
*/
// TODO: Add an example once we have the `nullSafe` concept.
def guarantees(other: Partitioning): Boolean
Contributor:

Do you think semanticEqual is a better name? I think this method is basically doing an equality check. For example, if other is a HashPartitioning('a :: Nil, 10) and this is a SinglePartition, we probably do not want to return true, because the parent of this operator can be a join and the sibling of this operator can be HashPartitioned.

Contributor Author:

I think that we should only consider a name like semanticEqual or semanticEquiv if a.guarantees(b) implies b.guarantees(a) and vice-versa.

Contributor:

Yeah, makes sense. Then semanticEqual is not a good name, because once we have the concept of nullSafe this method will not have the commutative property: a nullSafe hash partitioning can be treated as a nullUnsafe hash partitioning, but not the other way around.

Contributor Author:

Yep. Let's leave this for now.
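
To make the asymmetry in this thread concrete, here is a minimal standalone sketch of the nullSafe scenario; the toy model below is hypothetical and is not part of this patch:

```scala
// Hypothetical toy model of the nullSafe discussion above; not Spark code.
case class ToyHashPartitioning(keys: Set[String], nullSafe: Boolean) {
  // A null-safe layout also clusters rows with null keys, so it can stand
  // in for a null-unsafe one; the reverse does not hold.
  def guarantees(other: ToyHashPartitioning): Boolean =
    keys == other.keys && (nullSafe || !other.nullSafe)
}

object GuaranteesIsNotCommutative extends App {
  val safe   = ToyHashPartitioning(Set("a"), nullSafe = true)
  val unsafe = ToyHashPartitioning(Set("a"), nullSafe = false)
  println(safe.guarantees(unsafe)) // true
  println(unsafe.guarantees(safe)) // false: guarantees is not a semantic equality
}
```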

}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -96,23 +101,29 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
override def guarantees(other: Partitioning): Boolean = false
}

case object SinglePartition extends Partitioning {
val numPartitions = 1

override def satisfies(required: Distribution): Boolean = true

override def keyExpressions: Seq[Expression] = Nil
override def guarantees(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}
}

case object BroadcastPartitioning extends Partitioning {
val numPartitions = 1

override def satisfies(required: Distribution): Boolean = true

override def keyExpressions: Seq[Expression] = Nil
override def guarantees(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case _ => false
}
}

/**
@@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType

private[this] lazy val clusteringSet = expressions.toSet
lazy val clusteringSet = expressions.toSet

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
@@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = expressions
override def guarantees(other: Partitioning): Boolean = other match {
case o: HashPartitioning =>
this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
case _ => false
}
}

/**
@@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = ordering.map(_.child)
override def guarantees(other: Partitioning): Boolean = other match {
case o: RangePartitioning => this == o
case _ => false
}
}

/**
* A collection of [[Partitioning]]s that can be used to describe the partitioning
* scheme of the output of a physical operator. It is usually used for an operator
* that has multiple children. In this case, a [[Partitioning]] in this collection
* describes how this operator's output is partitioned based on expressions from
* a child. For example, for a Join operator on two tables `A` and `B`
* with a join condition `A.key1 = B.key2`, assuming we use the HashPartitioning scheme,
* there are two [[Partitioning]]s that can be used to describe how the output of
* this Join operator is partitioned, which are `HashPartitioning(A.key1)` and
* `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
* in this collection do not need to be equivalent, which is useful for
* Outer Join operators.
*/
case class PartitioningCollection(partitionings: Seq[Partitioning])
extends Expression with Partitioning with Unevaluable {

require(
partitionings.map(_.numPartitions).distinct.length == 1,
s"PartitioningCollection requires all of its partitionings have the same numPartitions.")

override def children: Seq[Expression] = partitionings.collect {
case expr: Expression => expr
}

override def nullable: Boolean = false

override def dataType: DataType = IntegerType

override val numPartitions = partitionings.map(_.numPartitions).distinct.head

/**
* Returns true if any `partitioning` of this collection satisfies the given
* [[Distribution]].
*/
override def satisfies(required: Distribution): Boolean =
partitionings.exists(_.satisfies(required))

/**
* Returns true if any `partitioning` of this collection guarantees
* the given [[Partitioning]].
*/
override def guarantees(other: Partitioning): Boolean =
partitionings.exists(_.guarantees(other))

override def toString: String = {
partitionings.map(_.toString).mkString("(", " or ", ")")
}
}
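
As a usage sketch of the new class (a hedged example written against the definitions in this diff; `aKey1` and `bKey2` are made-up attribute references standing in for `A.key1` and `B.key2`):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

object PartitioningCollectionDemo extends App {
  val aKey1 = AttributeReference("key1", IntegerType)()
  val bKey2 = AttributeReference("key2", IntegerType)()

  // A join on A.key1 = B.key2 can describe its output with both sides' keys.
  val joinOutput = PartitioningCollection(
    Seq(HashPartitioning(Seq(aKey1), 10), HashPartitioning(Seq(bKey2), 10)))

  // The collection satisfies a required clustering on either key ...
  assert(joinOutput.satisfies(ClusteredDistribution(Seq(aKey1))))
  assert(joinOutput.satisfies(ClusteredDistribution(Seq(bKey2))))

  // ... and guarantees each member partitioning, which is what lets a
  // downstream join on either key proceed without a new Exchange.
  assert(joinOutput.guarantees(HashPartitioning(Seq(aKey1), 10)))
}
```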
@@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite {
}
}

test("HashPartitioning is the output partitioning") {
test("HashPartitioning (with nullSafe = true) is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
@@ -209,7 +209,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child: SparkPlan): SparkPlan = {

def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
if (child.outputPartitioning != partitioning) {
if (!child.outputPartitioning.guarantees(partitioning)) {
Exchange(partitioning, child)
} else {
child
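
A sketch of what this one-line change buys (hypothetical attribute, kept standalone on purpose): a PartitioningCollection is never `==` to a plain HashPartitioning, so the old check always inserted a shuffle, while `guarantees` looks inside the collection:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

object GuaranteesCheckDemo extends App {
  val key = AttributeReference("key", IntegerType)()
  val required = HashPartitioning(Seq(key), 10)
  val childOutput: Partitioning = PartitioningCollection(Seq(required))

  println(childOutput != required)          // true: old check adds an Exchange
  println(childOutput.guarantees(required)) // true: new check skips the shuffle
}
```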
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
@@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
Contributor:

This is a bug fix.


@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
@@ -22,7 +22,6 @@ import java.util.{HashMap => JavaHashMap}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
@@ -38,14 +37,6 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan

override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}

override def output: Seq[Attribute] = {
joinType match {
case LeftOuter =>
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
@@ -37,7 +37,9 @@ case class LeftSemiJoinHash(
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {

override def requiredChildDistribution: Seq[ClusteredDistribution] =
override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

protected override def doExecute(): RDD[InternalRow] = {
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
Expand All @@ -38,9 +38,10 @@ case class ShuffledHashJoin(
right: SparkPlan)
extends BinaryNode with HashJoin {

override def outputPartitioning: Partitioning = left.outputPartitioning
override def outputPartitioning: Partitioning =
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

override def requiredChildDistribution: Seq[ClusteredDistribution] =
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

protected override def doExecute(): RDD[InternalRow] = {
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

@@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}

Contributor:

I will remove this change for now. Once we have the nullSafe concept, we can better describe how the result of this join operator is partitioned. For example, right now, it is not safe to say that the output of this operator is partitioned by the rightKeys when we have a left outer join (because rows with null keys are not clustered).
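
A tiny illustration of that point, with made-up rows: in a left outer join, an unmatched left row emits null right keys in whatever partition the left row happened to live in, so null right keys can appear in several partitions at once:

```scala
// Hypothetical partition contents after `l LEFT OUTER JOIN r ON l.k = r.k`,
// where left rows "b" and "c" have no match on the right:
val partition0 = Seq(("a", "a"), ("b", null)) // pairs of (l.k, r.k)
val partition1 = Seq(("c", null))
// r.k == null shows up in both partitions, so the output is not
// clustered on the right keys; hence the change is dropped for now.
```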

protected override def doExecute(): RDD[InternalRow] = {
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
@@ -40,7 +40,8 @@ case class SortMergeJoin(

override def output: Seq[Attribute] = left.output ++ right.output

override def outputPartitioning: Partitioning = left.outputPartitioning
override def outputPartitioning: Partitioning =
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
@@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLConf, execution}
import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}


class PlannerSuite extends SparkFunSuite {
class PlannerSuite extends SparkFunSuite with SQLTestUtils {

override def sqlContext: SQLContext = TestSQLContext

private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
@@ -157,4 +161,45 @@ class PlannerSuite extends SparkFunSuite {
val planned = planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}

test("PartitioningCollection") {
withTempTable("normal", "small", "tiny") {
testData.registerTempTable("normal")
testData.limit(10).registerTempTable("small")
testData.limit(3).registerTempTable("tiny")

// Disable broadcast join
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
{
val numExchanges = sql(
"""
|SELECT *
|FROM
| normal JOIN small ON (normal.key = small.key)
| JOIN tiny ON (small.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
assert(numExchanges === 3)
}

{
// This second query joins on different keys:
val numExchanges = sql(
"""
|SELECT *
|FROM
| normal JOIN small ON (normal.key = small.key)
| JOIN tiny ON (normal.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
assert(numExchanges === 3)
}

}
}
}
}