From 7e52db3e6ff6e2dffe0c599730b8019319fda185 Mon Sep 17 00:00:00 2001
From: jeanlyn
Date: Sat, 6 Jun 2015 18:22:57 +0800
Subject: [PATCH 1/3] remove unnecessary exchange

---
 .../apache/spark/sql/execution/Exchange.scala | 30 ++++++++++++++++--
 .../spark/sql/execution/SparkPlan.scala       |  9 ++++++
 .../spark/sql/execution/PlannerSuite.scala    | 31 +++++++++++++++++++
 3 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index f25d10fec0411..ff5ffc6f3fc1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -303,15 +303,23 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     def addOperatorsIfNecessary(
         partitioning: Partitioning,
         rowOrdering: Seq[SortOrder],
-        child: SparkPlan): SparkPlan = {
+        child: SparkPlan,
+        isUnaryNodeWithRequire: Boolean = false): SparkPlan = {
       val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
-      val needsShuffle = child.outputPartitioning != partitioning
+      val needsShuffle = if (isUnaryNodeWithRequire) {
+        child.outputPartitioning != partitioning
+      } else {
+        !child.meetPartitions(partitioning)
+      }
       val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)

       if (needSort && needsShuffle && canSortWithShuffle) {
         Exchange(partitioning, rowOrdering, child)
       } else {
         val withShuffle = if (needsShuffle) {
+          // Reset meetPartitions to the operator's own outputPartitioning when a
+          // shuffle is needed, because a shuffle breaks the transitivity
+          operator.meetPartitions = Set(operator.outputPartitioning)
           Exchange(partitioning, Nil, child)
         } else {
           child
@@ -343,13 +351,29 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     val fixedChildren = requirements.zipped.map {
       case (AllTuples, rowOrdering, child) =>
         addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+
       case (ClusteredDistribution(clustering), rowOrdering, child) =>
-        addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+        // For a UnaryNode we use the child's outputPartitioning to judge
+        // whether a shuffle is needed
+        if (operator.isInstanceOf[UnaryNode]) {
+          addOperatorsIfNecessary(
+            HashPartitioning(clustering, numPartitions),
+            rowOrdering,
+            child,
+            true)
+        } else {
+          addOperatorsIfNecessary(
+            HashPartitioning(clustering, numPartitions),
+            rowOrdering,
+            child)
+        }
+
       case (OrderedDistribution(ordering), rowOrdering, child) =>
         addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)

       case (UnspecifiedDistribution, Seq(), child) => child
+
       case (UnspecifiedDistribution, rowOrdering, child) =>
         if (sqlContext.conf.externalSortEnabled) {
           ExternalSort(rowOrdering, global = false, child)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 435ac011178de..289ba83597823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -58,6 +58,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     false
   }

+  // Used to judge whether a shuffle is needed
+  var meetPartitions: Set[Partitioning] = Set.empty
+
   /** Overridden make copy also propogates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): this.type = {
     SparkPlan.currentContext.set(sqlContext)
@@ -194,13 +197,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

 private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
   self: Product =>
+  meetPartitions = Set(outputPartitioning)
 }

 private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
   self: Product =>
   override def outputPartitioning: Partitioning = child.outputPartitioning
+  meetPartitions = child.meetPartitions ++ Set(outputPartitioning)
 }

 private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
   self: Product =>
+  meetPartitions =
+    left.meetPartitions ++
+    right.meetPartitions ++
+    Set(left.outputPartitioning, right.outputPartitioning)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 45a7e8fe68f72..c84028d199691 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -142,4 +142,35 @@ class PlannerSuite extends SparkFunSuite {

     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }
+
+  test("unnecessary exchange operators test") {
+    val planned1 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData("key") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges1 = planned1.collect { case n: Exchange => n }
+    assert(exchanges1.size === 3)
+
+    val planned2 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData2("a") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges2 = planned2.collect { case n: Exchange => n }
+    assert(exchanges2.size === 3)
+
+    val planned3 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .join(testData3, testData("value") === testData3("a"), "outer")
+      .queryExecution.executedPlan
+    val exchanges3 = planned3.collect { case n: Exchange => n }
+    assert(exchanges3.size === 4)
+
+    val planned4 = testData
+      .join(testData2, testData("key") === testData2("a"), "outer")
+      .groupBy(testData2("a")).agg(count('key), count('value))
+      .queryExecution.executedPlan
+    val exchanges4 = planned4.collect { case n: Exchange => n }
+    assert(exchanges4.size === 3)
+  }
 }
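The first patch makes each plan node remember the set of partitionings already "met" below it, and `EnsureRequirements` inserts an `Exchange` only when the required partitioning is not in that set. The following is a minimal, self-contained sketch of that bookkeeping; `Leaf`, `Binary` and `needsShuffle` are toy stand-ins invented here for illustration, not the real Spark classes:

```scala
// Toy model of the meetPartitions idea; none of these types are Spark's own.
object MeetPartitionsSketch extends App {
  sealed trait Partitioning
  case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning
  case object UnknownPartitioning extends Partitioning

  sealed trait Plan {
    def outputPartitioning: Partitioning
    def meetPartitions: Set[Partitioning]
  }
  case class Leaf(outputPartitioning: Partitioning) extends Plan {
    val meetPartitions: Set[Partitioning] = Set(outputPartitioning)
  }
  // A binary node (e.g. a join) accumulates every partitioning met below it.
  case class Binary(left: Plan, right: Plan, outputPartitioning: Partitioning) extends Plan {
    val meetPartitions: Set[Partitioning] =
      left.meetPartitions ++ right.meetPartitions ++
        Set(left.outputPartitioning, right.outputPartitioning)
  }

  // A shuffle is only needed when the required partitioning was never met.
  def needsShuffle(required: Partitioning, child: Plan): Boolean =
    !child.meetPartitions.contains(required)

  val onKey = HashPartitioning(Seq("key"), 200)
  val firstJoin = Binary(Leaf(onKey), Leaf(UnknownPartitioning), onKey)
  // The rows were already hash-partitioned on "key" below this point,
  // so a second operator clustered on "key" needs no new Exchange:
  println(needsShuffle(onKey, firstJoin)) // false
}
```

This is the behavior the `planned1` assertion above appears to check: the second outer join reuses the hash partitioning on `testData("key")` that the first join's `Exchange` already established, so the plan ends up with three `Exchange` operators rather than four.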
From 6113ca4c9a00e8de62c46e13f42cce44cfc5a24f Mon Sep 17 00:00:00 2001
From: jeanlyn
Date: Mon, 8 Jun 2015 13:16:56 +0800
Subject: [PATCH 2/3] fix bug when resetting meetPartitions fails

---
 .../apache/spark/sql/execution/Exchange.scala  | 18 ++++++++++++++----
 .../apache/spark/sql/execution/SparkPlan.scala |  6 ------
 .../spark/sql/execution/PlannerSuite.scala     |  2 +-
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index ff5ffc6f3fc1f..120b99d2959d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -60,6 +60,9 @@ case class Exchange(

   override def output: Seq[Attribute] = child.output

+  // Reset meetPartitions when adding an `Exchange`
+  meetPartitions = Set(newPartitioning)
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -317,9 +320,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
         Exchange(partitioning, rowOrdering, child)
       } else {
         val withShuffle = if (needsShuffle) {
-          // Reset meetPartitions to the operator's own outputPartitioning when a
-          // shuffle is needed, because a shuffle breaks the transitivity
-          operator.meetPartitions = Set(operator.outputPartitioning)
           Exchange(partitioning, Nil, child)
         } else {
           child
@@ -340,6 +340,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       }

       if (meetsRequirements && compatible && !needsAnySort) {
+        val childrenPartition = operator.children.map(_.outputPartitioning).toSet
+        operator.meetPartitions = operator.children.foldLeft(childrenPartition) {
+          (m, c) => m ++ c.meetPartitions
+        }
         operator
       } else {
         // At least one child does not satisfies its required data distribution or
@@ -385,7 +389,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
           sys.error(s"Don't know how to ensure $dist with ordering $ordering")
       }

-      operator.withNewChildren(fixedChildren)
+      // Set meetPartitions from the fixed children
+      val childrenPartition = fixedChildren.map(_.outputPartitioning).toSet
+      val o = operator.withNewChildren(fixedChildren)
+      o.meetPartitions = fixedChildren.foldLeft(childrenPartition) {
+        (m, c) => m ++ c.meetPartitions
+      }
+      o
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 289ba83597823..e0635202dadfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -197,19 +197,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

 private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
   self: Product =>
-  meetPartitions = Set(outputPartitioning)
 }

 private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] {
   self: Product =>
   override def outputPartitioning: Partitioning = child.outputPartitioning
-  meetPartitions = child.meetPartitions ++ Set(outputPartitioning)
 }

 private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] {
   self: Product =>
-  meetPartitions =
-    left.meetPartitions ++
-    right.meetPartitions ++
-    Set(left.outputPartitioning, right.outputPartitioning)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c84028d199691..730b0b60464af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -143,7 +143,7 @@ class PlannerSuite extends SparkFunSuite {
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }

-  test("unnecessary exchange operators test") {
+  test("unnecessary exchange operators") {
     val planned1 = testData
       .join(testData2, testData("key") === testData2("a"), "outer")
       .join(testData3, testData("key") === testData3("a"), "outer")
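After this patch the bookkeeping lives entirely in `EnsureRequirements`: an `Exchange` resets `meetPartitions` to just its fresh output partitioning (a shuffle invalidates any co-location established below it), while every other operator seeds the set from its possibly-fixed children and folds in what those children met further down. A standalone illustration of that fold, using a `String` stand-in for `Partitioning` and made-up node contents:

```scala
object MeetPartitionsFoldSketch extends App {
  type Partitioning = String // illustrative stand-in for Spark's Partitioning

  case class Node(outputPartitioning: Partitioning, meetPartitions: Set[Partitioning])

  val left = Node("hash(key)", Set("hash(key)"))
  val right = Node("hash(a)", Set("hash(a)", "range(b)"))
  val fixedChildren = Seq(left, right)

  // Seed with the children's own output partitionings, then fold in
  // everything each child already met further down the tree.
  val childrenPartition = fixedChildren.map(_.outputPartitioning).toSet
  val meet = fixedChildren.foldLeft(childrenPartition) { (m, c) => m ++ c.meetPartitions }
  println(meet) // contains hash(key), hash(a) and range(b)

  // An Exchange, by contrast, starts over: after a repartition the old
  // co-location guarantees no longer hold for the shuffled rows.
  def exchanged(newPartitioning: Partitioning): Node =
    Node(newPartitioning, Set(newPartitioning))
  println(exchanged("range(c)").meetPartitions) // Set(range(c))
}
```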
From 0c1134bfbd0d58c8670899d1b6c78b86e28b2acb Mon Sep 17 00:00:00 2001
From: jeanlyn
Date: Mon, 8 Jun 2015 20:27:40 +0800
Subject: [PATCH 3/3] UnaryNode uses outputPartitioning to judge whether a
 shuffle is needed

---
 .../apache/spark/sql/execution/Exchange.scala | 23 +++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 120b99d2959d7..5bd521f0f15aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -354,7 +354,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     val fixedChildren = requirements.zipped.map {
       case (AllTuples, rowOrdering, child) =>
-        addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+        // For a UnaryNode we use the child's outputPartitioning to judge
+        // whether a shuffle is needed
+        if (operator.isInstanceOf[UnaryNode]) {
+          addOperatorsIfNecessary(SinglePartition, rowOrdering, child, true)
+        } else {
+          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+        }

       case (ClusteredDistribution(clustering), rowOrdering, child) =>
         // For a UnaryNode we use the child's outputPartitioning to judge
         // whether a shuffle is needed
@@ -373,7 +379,20 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
             child)
         }

       case (OrderedDistribution(ordering), rowOrdering, child) =>
-        addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+        // For a UnaryNode we use the child's outputPartitioning to judge
+        // whether a shuffle is needed
+        if (operator.isInstanceOf[UnaryNode]) {
+          addOperatorsIfNecessary(
+            RangePartitioning(ordering, numPartitions),
+            rowOrdering,
+            child,
+            true)
+        } else {
+          addOperatorsIfNecessary(
+            RangePartitioning(ordering, numPartitions),
+            rowOrdering,
+            child)
+        }

       case (UnspecifiedDistribution, Seq(), child) => child
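The final patch extends the UnaryNode special case to all three distribution requirements: a unary operator consumes its child's layout directly, so only the child's own `outputPartitioning` is compared against the requirement, while other operators may still reuse any partitioning met further below. A standalone sketch of the two predicates, again with a `String` stand-in for `Partitioning` and names invented for illustration:

```scala
object NeedsShuffleSketch extends App {
  type Partitioning = String // illustrative stand-in for Spark's Partitioning

  case class Child(outputPartitioning: Partitioning, meetPartitions: Set[Partitioning])

  // Mirrors the two branches of `needsShuffle` in addOperatorsIfNecessary.
  def needsShuffle(
      required: Partitioning,
      child: Child,
      isUnaryNodeWithRequire: Boolean): Boolean =
    if (isUnaryNodeWithRequire) {
      child.outputPartitioning != required // trust only the direct output
    } else {
      !child.meetPartitions.contains(required) // reuse anything met below
    }

  val child = Child("hash(a)", Set("hash(a)", "hash(key)"))
  println(needsShuffle("hash(key)", child, isUnaryNodeWithRequire = true))  // true
  println(needsShuffle("hash(key)", child, isUnaryNodeWithRequire = false)) // false
}
```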