diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index f25d10fec0411..5bd521f0f15aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -60,6 +60,9 @@ case class Exchange(
 
   override def output: Seq[Attribute] = child.output
 
+  // Reset meetPartitions when an `Exchange` is added.
+  meetPartitions = Set(newPartitioning)
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -303,9 +306,14 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       def addOperatorsIfNecessary(
           partitioning: Partitioning,
           rowOrdering: Seq[SortOrder],
-          child: SparkPlan): SparkPlan = {
+          child: SparkPlan,
+          isUnaryNodeWithRequire: Boolean = false): SparkPlan = {
         val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
-        val needsShuffle = child.outputPartitioning != partitioning
+        val needsShuffle = if (isUnaryNodeWithRequire) {
+          child.outputOrdering != partitioning
+        } else {
+          !child.meetPartitions(partitioning)
+        }
         val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
 
         if (needSort && needsShuffle && canSortWithShuffle) {
@@ -332,6 +340,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       }
 
       if (meetsRequirements && compatible && !needsAnySort) {
+        val childrenPartition = operator.children.map(_.outputPartitioning).toSet
+        operator.meetPartitions = operator.children.foldLeft(childrenPartition) {
+          (m, c) => m ++ c.meetPartitions
+        }
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
@@ -342,14 +354,49 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
 
         val fixedChildren = requirements.zipped.map {
           case (AllTuples, rowOrdering, child) =>
-            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+            // use outputOrdering to decide whether a shuffle is needed
+            // when the operator is a UnaryNode
+            if (operator.isInstanceOf[UnaryNode]) {
+              addOperatorsIfNecessary(SinglePartition, rowOrdering, child, true)
+            } else {
+              addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+            }
+
           case (ClusteredDistribution(clustering), rowOrdering, child) =>
-            addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+            // use outputOrdering to decide whether a shuffle is needed
+            // when the operator is a UnaryNode
+            if (operator.isInstanceOf[UnaryNode]) {
+              addOperatorsIfNecessary(
+                HashPartitioning(clustering, numPartitions),
+                rowOrdering,
+                child,
+                true)
+            } else {
+              addOperatorsIfNecessary(
+                HashPartitioning(clustering, numPartitions),
+                rowOrdering,
+                child)
+            }
+
           case (OrderedDistribution(ordering), rowOrdering, child) =>
-            addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+            // use outputOrdering to decide whether a shuffle is needed
+            // when the operator is a UnaryNode
+            if (operator.isInstanceOf[UnaryNode]) {
+              addOperatorsIfNecessary(
+                RangePartitioning(ordering, numPartitions),
+                rowOrdering,
+                child,
+                true)
+            } else {
+              addOperatorsIfNecessary(
+                RangePartitioning(ordering, numPartitions),
+                rowOrdering,
+                child)
+            }
 
           case (UnspecifiedDistribution, Seq(), child) => child
+
           case (UnspecifiedDistribution, rowOrdering, child) =>
             if (sqlContext.conf.externalSortEnabled) {
               ExternalSort(rowOrdering, global = false, child)
@@ -361,7 +408,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
             sys.error(s"Don't know how to ensure $dist with ordering $ordering")
         }
 
-        operator.withNewChildren(fixedChildren)
+        // Set meetPartitions from the fixed children.
+        val childrenPartition = fixedChildren.map(_.outputPartitioning).toSet
+        val o = operator.withNewChildren(fixedChildren)
+        o.meetPartitions = fixedChildren.foldLeft(childrenPartition) {
+          (m, c) => m ++ c.meetPartitions
+        }
+        o
       }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 435ac011178de..e0635202dadfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -58,6 +58,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     false
   }
 
+  // Used to decide whether a shuffle is needed.
+  var meetPartitions: Set[Partitioning] = Set.empty
+
   /** Overridden make copy also propogates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): this.type = {
     SparkPlan.currentContext.set(sqlContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 45a7e8fe68f72..730b0b60464af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -142,4 +142,35 @@ class PlannerSuite extends SparkFunSuite {
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
+
+  test("unnecessary exchange operators") {
+    val planned1 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData("key") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges1 = planned1.collect { case n: Exchange => n }
+    assert(exchanges1.size === 3)
+
+    val planned2 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData2("a") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges2 = planned2.collect { case n: Exchange => n }
+    assert(exchanges2.size === 3)
+
+    val planned3 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData("value") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges3 = planned3.collect { case n: Exchange => n }
+    assert(exchanges3.size === 4)
+
+    val planned4 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .groupBy(testData2("a")).agg(count('key), count('value))
+      .queryExecution.executedPlan
+    val exchanges4 = planned4.collect { case n: Exchange => n }
+    assert(exchanges4.size === 3)
+  }
+
 }
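
For readers skimming the patch, here is a minimal standalone sketch of the meetPartitions bookkeeping it introduces. It is not part of the change and not Spark's API: Plan, HashPartitioning, and SinglePartition below are simplified stand-ins for SparkPlan and Spark's Partitioning classes. The idea it illustrates is the one in the diff: each node accumulates every partitioning already produced beneath it, and a shuffle is only considered necessary when the required partitioning is absent from that set.

// Standalone sketch, assuming simplified stand-in types (not Spark classes).
object MeetPartitionsSketch {
  sealed trait Partitioning
  case class HashPartitioning(keys: Seq[String]) extends Partitioning
  case object SinglePartition extends Partitioning

  case class Plan(
      name: String,
      outputPartitioning: Partitioning,
      children: Seq[Plan] = Nil,
      var meetPartitions: Set[Partitioning] = Set.empty)

  // Bottom-up pass: each node remembers every partitioning produced below it,
  // combining its children's outputPartitioning with their meetPartitions sets.
  def propagate(plan: Plan): Plan = {
    plan.children.foreach(propagate)
    plan.meetPartitions = plan.meetPartitions ++
      plan.children.map(_.outputPartitioning) ++
      plan.children.flatMap(_.meetPartitions)
    plan
  }

  // A shuffle (Exchange) is only needed when the required partitioning has
  // not already been met somewhere below this child.
  def needsShuffle(child: Plan, required: Partitioning): Boolean =
    !child.meetPartitions.contains(required)

  def main(args: Array[String]): Unit = {
    // An Exchange-like leaf seeds its own set, mirroring the patch's
    // `meetPartitions = Set(newPartitioning)` in Exchange.
    val shuffled = Plan("Exchange", HashPartitioning(Seq("key")),
      meetPartitions = Set(HashPartitioning(Seq("key"))))
    val join = propagate(
      Plan("SortMergeJoin", HashPartitioning(Seq("key")), Seq(shuffled)))
    // A parent that again requires hash-partitioning on "key" needs no new shuffle.
    println(needsShuffle(join, HashPartitioning(Seq("key")))) // false
  }
}

This mirrors what the patch's test expects: in planned1, the second outer join reuses the hash-partitioning on testData("key") already established for the first join, so only three exchanges are planned instead of four.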