diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index debbd8d7c26c9..437bbaae1968b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.math.log10 import scala.reflect.ClassTag import scala.util.hashing.byteswap32 @@ -42,7 +43,9 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, choose that one. + * If any of the RDDs already has a partitioner, and the number of partitions of the + * partitioner is either greater than or is less than and within a single order of + * magnitude of the max number of upstream partitions, choose that one. * * Otherwise, we use a default HashPartitioner. For the number of partitions, if * spark.default.parallelism is set, then we'll use the value from SparkContext @@ -57,8 +60,15 @@ object Partitioner { def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val rdds = (Seq(rdd) ++ others) val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0)) - if (hasPartitioner.nonEmpty) { - hasPartitioner.maxBy(_.partitions.length).partitioner.get + + val hasMaxPartitioner: Option[RDD[_]] = if (hasPartitioner.nonEmpty) { + Some(hasPartitioner.maxBy(_.partitions.length)) + } else { + None + } + + if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + hasMaxPartitioner.get.partitioner.get } else { if (rdd.context.conf.contains("spark.default.parallelism")) { new HashPartitioner(rdd.context.defaultParallelism) @@ -67,6 +77,22 @@ object Partitioner { } } } + + /** + * Returns true if the number of partitions of the RDD is either greater + * than or is less than and within a single order of magnitude of the + * max number of upstream partitions; + * otherwise, returns false + */ + private def isEligiblePartitioner( + hasMaxPartitioner: Option[RDD[_]], + rdds: Seq[RDD[_]]): Boolean = { + if (hasMaxPartitioner.isEmpty) { + return false + } + val maxPartitions = rdds.map(_.partitions.length).max + log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + } } /** diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index dfe4c25670ce0..155ca17db726b 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -259,6 +259,33 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva val partitioner = new RangePartitioner(22, rdd) assert(partitioner.numPartitions === 3) } + + test("defaultPartitioner") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc + .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc + .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + + assert(partitioner1.numPartitions == rdd1.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 65d35264dc108..a39e0469272fe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.size > 0) } + // See SPARK-22465 + test("cogroup between multiple RDD " + + "with an order of magnitude difference in number of partitions") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd1.getNumPartitions) + } + + // See SPARK-22465 + test("cogroup between multiple RDD" + + " with number of partitions similar in order of magnitude") { + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc + .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))