-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-22465][Core] Add a safety-check to RDD defaultPartitioner #20002
Changes from 4 commits
176270b
be391a7
4b2dcac
ca6aa08
961e384
8b35452
4729d80
7d88e6c
6623227
3dd1ad8
62b17e9
3b08951
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} | |
|
||
import scala.collection.mutable | ||
import scala.collection.mutable.ArrayBuffer | ||
import scala.math.log10 | ||
import scala.reflect.ClassTag | ||
import scala.util.hashing.byteswap32 | ||
|
||
|
@@ -42,7 +43,9 @@ object Partitioner { | |
/** | ||
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs. | ||
* | ||
* If any of the RDDs already has a partitioner, choose that one. | ||
* If any of the RDDs already has a partitioner, and the number of partitions of the | ||
* partitioner is either greater than or is less than and within a single order of | ||
* magnitude of the max number of upstream partitions, choose that one. | ||
* | ||
* Otherwise, we use a default HashPartitioner. For the number of partitions, if | ||
* spark.default.parallelism is set, then we'll use the value from SparkContext | ||
|
@@ -57,7 +60,8 @@ object Partitioner { | |
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { | ||
val rdds = (Seq(rdd) ++ others) | ||
val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) | ||
if (hasPartitioner.nonEmpty) { | ||
if (hasPartitioner.nonEmpty | ||
&& isEligiblePartitioner(hasPartitioner.maxBy(_.partitions.length), rdds)) { | ||
hasPartitioner.maxBy(_.partitions.length).partitioner.get | ||
} else { | ||
if (rdd.context.conf.contains("spark.default.parallelism")) { | ||
|
@@ -67,6 +71,16 @@ object Partitioner { | |
} | ||
} | ||
} | ||
|
||
/** | ||
* Returns true if the number of partitions of the RDD is either greater than or is | ||
* less than and within a single order of magnitude of the max number of upstream partitions; | ||
* otherwise, returns false | ||
*/ | ||
private def isEligiblePartitioner(hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = { | ||
val maxPartitions = rdds.map(_.partitions.length).max | ||
log10(maxPartitions).floor - log10(hasMaxPartitioner.getNumPartitions).floor < 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @mridulm , I suppose I was trying to ensure a strict order-of-magnitude check; but, I agree it leads to a discontinuity. I will change this, and the corresponding test cases. |
||
} | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -259,6 +259,27 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva | |
val partitioner = new RangePartitioner(22, rdd) | ||
assert(partitioner.numPartitions === 3) | ||
} | ||
|
||
test("defaultPartitioner") { | ||
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) | ||
val rdd2 = sc | ||
.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) | ||
.partitionBy(new HashPartitioner(10)) | ||
val rdd3 = sc | ||
.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) | ||
.partitionBy(new HashPartitioner(100)) | ||
|
||
val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) | ||
val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) | ||
val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) | ||
val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) | ||
|
||
assert(partitioner1.numPartitions == rdd1.getNumPartitions) | ||
assert(partitioner2.numPartitions == rdd3.getNumPartitions) | ||
assert(partitioner3.numPartitions == rdd3.getNumPartitions) | ||
assert(partitioner4.numPartitions == rdd3.getNumPartitions) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a testcase such that numPartitions 9 vs 11 is not treated as an order of magnitude jump (to prevent future changes which end up breaking this). |
||
|
||
} | ||
} | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`hasPartitioner.maxBy(_.partitions.length)` is used repeatedly — consider pulling it into a variable?