[SPARK-22465][Core] Add a safety-check to RDD defaultPartitioner #20002

Closed · wants to merge 12 commits into from
18 changes: 16 additions & 2 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.log10
import scala.reflect.ClassTag
import scala.util.hashing.byteswap32

@@ -42,7 +43,9 @@ object Partitioner {
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
* If any of the RDDs already has a partitioner, choose that one.
* If any of the RDDs already has a partitioner, and the number of partitions of that
* partitioner is either greater than the max number of upstream partitions, or is less
* than it but within a single order of magnitude, choose that one.
*
* Otherwise, we use a default HashPartitioner. For the number of partitions, if
* spark.default.parallelism is set, then we'll use the value from SparkContext
@@ -57,7 +60,8 @@ object Partitioner {
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val rdds = (Seq(rdd) ++ others)
val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
if (hasPartitioner.nonEmpty) {
if (hasPartitioner.nonEmpty
&& isEligiblePartitioner(hasPartitioner.maxBy(_.partitions.length), rdds)) {
Contributor comment: hasPartitioner.maxBy(_.partitions.length) is used repeatedly; pull that into a variable?
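A minimal sketch of the refactor being suggested here, hoisting the repeated maxBy call into a local value (the name hasMaxPartitioner is only illustrative, not the final change):

// Hoist the repeated maxBy into one place; only evaluate it when a
// partitioned RDD exists, since maxBy on an empty Seq would throw.
val hasMaxPartitioner: Option[RDD[_]] =
  if (hasPartitioner.nonEmpty) Some(hasPartitioner.maxBy(_.partitions.length)) else None

if (hasMaxPartitioner.isDefined && isEligiblePartitioner(hasMaxPartitioner.get, rdds)) {
  hasMaxPartitioner.get.partitioner.get
} else {
  // fall through to the default HashPartitioner selection below
}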

hasPartitioner.maxBy(_.partitions.length).partitioner.get
} else {
if (rdd.context.conf.contains("spark.default.parallelism")) {
@@ -67,6 +71,16 @@
}
}
}

/**
* Returns true if the number of partitions of the RDD is either greater than the max
* number of upstream partitions, or is less than it but within a single order of
* magnitude; otherwise, returns false.
*/
private def isEligiblePartitioner(hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = {
val maxPartitions = rdds.map(_.partitions.length).max
log10(maxPartitions).floor - log10(hasMaxPartitioner.getNumPartitions).floor < 1
Contributor comment: Why .floor? It causes unnecessary discontinuity imo; for example, (9, 11) will not satisfy, but it should.

Author reply: Hi @mridulm, I suppose I was trying to ensure a strict order-of-magnitude check, but I agree it leads to a discontinuity. I will change this, and the corresponding test cases.
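A quick worked illustration of the discontinuity being pointed out (numbers are only for illustration): with the floor, 9 vs 11 partitions fails the check even though the two are well within an order of magnitude of each other, while the raw difference of logs passes.

import scala.math.log10

// With .floor: (9, 11) straddles a power of ten, so the difference of the
// floored logs is 1.0 - 0.0 = 1.0, which is not < 1, and the check fails.
log10(11).floor - log10(9).floor   // 1.0
// Without .floor: the raw difference is about 0.087, which is < 1, so it passes.
log10(11) - log10(9)               // ~0.087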

}
}

21 changes: 21 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -259,6 +259,27 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva
val partitioner = new RangePartitioner(22, rdd)
assert(partitioner.numPartitions === 3)
}

test("defaultPartitioner") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150)
val rdd2 = sc
.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
.partitionBy(new HashPartitioner(10))
val rdd3 = sc
.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
.partitionBy(new HashPartitioner(100))

val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2)
val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3)
val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1)
val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3)

assert(partitioner1.numPartitions == rdd1.getNumPartitions)
assert(partitioner2.numPartitions == rdd3.getNumPartitions)
assert(partitioner3.numPartitions == rdd3.getNumPartitions)
assert(partitioner4.numPartitions == rdd3.getNumPartitions)
Contributor comment: Can you add a testcase such that numPartitions 9 vs 11 is not treated as an order-of-magnitude jump (to prevent future changes from breaking this)?
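A sketch of the kind of test case being requested, assuming the .floor is dropped as discussed above so that 9 vs 11 counts as the same order of magnitude (the RDD names and values here are illustrative):

// 11 upstream partitions vs an existing HashPartitioner(9): not an
// order-of-magnitude jump, so the existing partitioner should be reused.
val rdd4 = sc.parallelize((1 to 100).map(x => (x, x)), 11)
val rdd5 = sc
  .parallelize(Array((1, 2), (2, 3), (3, 4)))
  .partitionBy(new HashPartitioner(9))
val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5)
assert(partitioner5.numPartitions == rdd5.getNumPartitions)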


}
}


core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(joined.size > 0)
}

// See SPARK-22465
test("cogroup between multiple RDD " +
"with an order of magnitude difference in number of partitions") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000)
val rdd2 = sc
.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
.partitionBy(new HashPartitioner(10))
val joined = rdd1.cogroup(rdd2)
assert(joined.getNumPartitions == rdd1.getNumPartitions)
}

// See SPARK-22465
test("cogroup between multiple RDD" +
" with number of partitions similar in order of magnitude") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
val rdd2 = sc
.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
.partitionBy(new HashPartitioner(10))
val joined = rdd1.cogroup(rdd2)
assert(joined.getNumPartitions == rdd2.getNumPartitions)
}

test("rightOuterJoin") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))