Skip to content

Commit

Permalink
Fix parallelism in join operator unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Aug 11, 2015
1 parent 899dce2 commit e79909e
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,67 +34,75 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>

def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}

def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

test(s"$testName using BroadcastHashJoin (build=left)") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>

def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}

def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}

test(s"$testName using BroadcastHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using BroadcastHashJoin (build=right)") {
test(s"$testName using BroadcastHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using ShuffledHashJoin (build=left)") {
test(s"$testName using ShuffledHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using ShuffledHashJoin (build=right)") {
test(s"$testName using ShuffledHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using SortMergeJoin") {
test(s"$testName using SortMergeJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeSortMergeJoin(left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,52 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
joinType: JoinType,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using SortMergeOuterJoin") {
test(s"$testName using SortMergeOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
EnsureRequirements(sqlContext).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
}
}
}
}

test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
Expand All @@ -85,14 +94,19 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, 1.0),
Row(3, 3.0),
Row(5, 1.0),
Row(6, 6.0),
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))

val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(5, 3.0),
Row(7, 7.0),
Row(null, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))

Expand All @@ -117,7 +131,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
(3, 3.0, null, null),
(5, 1.0, 5, 3.0),
(6, 6.0, null, null)
)
)

Expand All @@ -129,12 +145,15 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
condition,
Seq(
(null, null, null, null),
(null, null, 0, 0.0),
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
(null, null, 4, 1.0),
(5, 1.0, 5, 3.0),
(null, null, 7, 7.0)
)
)

Expand All @@ -151,8 +170,12 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
(2, 1.0, 2, 3.0),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null),
(5, 1.0, 5, 3.0),
(6, 6.0, null, null),
(null, null, 0, 0.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0),
(null, null, 7, 7.0),
(null, null, null, null),
(null, null, null, null)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,31 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using LeftSemiJoinHash") {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using LeftSemiJoinHash") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}

test(s"$testName using BroadcastLeftSemiJoinHash") {
test(s"$testName using BroadcastLeftSemiJoinHash") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}

test(s"$testName using LeftSemiJoinBNL") {
test(s"$testName using LeftSemiJoinBNL") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
LeftSemiJoinBNL(left, right, Some(condition)),
expectedAnswer.map(Row.fromTuple),
Expand Down

0 comments on commit e79909e

Please sign in to comment.