Skip to content

Commit

Permalink
more test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Justin Uang committed Jan 14, 2019
1 parent c89f03b commit 563706e
Showing 1 changed file with 15 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.{RangeExec, SortExec}
import org.apache.spark.sql.execution.RangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.joins.{BuildRight, ShuffledHashJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

Expand Down Expand Up @@ -53,7 +52,7 @@ class PlanQueryStageTest extends SharedSQLContext {
conf.setConfString("spark.sql.exchange.reuse", "true")

val planQueryStage = PlanQueryStage(conf)
val newPlan = planQueryStage(createMergeJoinPlan(100, 100))
val newPlan = planQueryStage(createJoinExec(100, 100))

val collected = newPlan.collect {
case e: ShuffleQueryStageInput => e.childStage
Expand All @@ -68,7 +67,7 @@ class PlanQueryStageTest extends SharedSQLContext {
conf.setConfString("spark.sql.exchange.reuse", "true")

val planQueryStage = PlanQueryStage(conf)
val newPlan = planQueryStage(createMergeJoinPlan(100, 101))
val newPlan = planQueryStage(createJoinExec(100, 101))

val collected = newPlan.collect {
case e: ShuffleQueryStageInput => e.childStage
Expand All @@ -78,25 +77,20 @@ class PlanQueryStageTest extends SharedSQLContext {
assert(!collected(0).eq(collected(1)))
}

def createMergeJoinPlan(leftNum: Int, rightNum: Int): SortMergeJoinExec = {
val left = SortExec(
Seq(SortOrder(UnresolvedAttribute("blah"), Ascending)),
true,
ShuffleExchangeExec(
HashPartitioning(Seq(UnresolvedAttribute("blah")), 100),
RangeExec(org.apache.spark.sql.catalyst.plans.logical.Range(1, leftNum, 1, 1))))

val right = SortExec(
Seq(SortOrder(UnresolvedAttribute("blah"), Ascending)),
true,
ShuffleExchangeExec(
HashPartitioning(Seq(UnresolvedAttribute("blah")), 100),
RangeExec(org.apache.spark.sql.catalyst.plans.logical.Range(1, rightNum, 1, 1))))

SortMergeJoinExec(
def createJoinExec(leftNum: Int, rightNum: Int): ShuffledHashJoinExec = {
val left = ShuffleExchangeExec(
HashPartitioning(Seq(UnresolvedAttribute("blah")), 100),
RangeExec(org.apache.spark.sql.catalyst.plans.logical.Range(1, leftNum, 1, 1)))

val right = ShuffleExchangeExec(
HashPartitioning(Seq(UnresolvedAttribute("blah")), 100),
RangeExec(org.apache.spark.sql.catalyst.plans.logical.Range(1, rightNum, 1, 1)))

ShuffledHashJoinExec(
Seq(UnresolvedAttribute("blah")),
Seq(UnresolvedAttribute("blah")),
Inner,
BuildRight,
None,
left,
right)
Expand Down

0 comments on commit 563706e

Please sign in to comment.