diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 98c3abe93b553..8231ad9134613 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -75,11 +75,18 @@ object Partitioner {
  * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
  * produce an unexpected or incorrect result.
  */
-class HashPartitioner(partitions: Int) extends Partitioner {
+class HashPartitioner(partitions: Int, buckets: Int = 0) extends Partitioner {
   require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
+  require(buckets >= 0, s"Number of buckets ($buckets) cannot be negative.")
+
+  def this(partitions: Int) {
+    this(partitions, 0)
+  }
 
   def numPartitions: Int = partitions
 
+  def numBuckets: Int = buckets
+
   def getPartition(key: Any): Int = key match {
     case null => 0
     case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 195b775730f02..48f6edcf4ef2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -253,10 +253,10 @@ case object SinglePartition extends Partitioning {
  * in the same partition. Moreover while evaluating expressions if they are given in different order
  * than this partitioning then also it is considered equal.
  */
-case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+case class OrderlessHashPartitioning(expressions: Seq[Expression],
+    numPartitions: Int, numBuckets: Int)
   extends Expression with Partitioning with Unevaluable {
-
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
   override def dataType: DataType = IntegerType
 
@@ -274,6 +274,7 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
   }
 
   private def anyOrderEquals(other: HashPartitioning) : Boolean = {
+    other.numBuckets == this.numBuckets &&
     other.numPartitions == this.numPartitions &&
       matchExpressions(other.expressions)
   }
@@ -284,7 +285,7 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
   }
 
   override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning => anyOrderEquals(o)
+    case p: HashPartitioning => anyOrderEquals(p)
     case _ => false
   }
 
@@ -295,8 +296,8 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
  * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
-  extends Expression with Partitioning with Unevaluable {
+case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int,
+    numBuckets: Int = 0) extends Expression with Partitioning with Unevaluable {
 
   override def children: Seq[Expression] = expressions
   override def nullable: Boolean = false
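As a sanity check on the widened constructor, here is a minimal sketch against the patched class (the values 8 and 4 are arbitrary illustrations, not from the patch): row placement is untouched, and the bucket count only rides along as metadata for downstream equality checks such as `OrderlessHashPartitioning.anyOrderEquals`.

```scala
import org.apache.spark.HashPartitioner

// A one-arg call resolves to the preserved auxiliary constructor (buckets = 0);
// the two-arg call uses the widened primary constructor.
val plain = new HashPartitioner(8)
val bucketed = new HashPartitioner(8, 4)

// getPartition is unchanged: still nonNegativeMod(key.hashCode, numPartitions),
// so the bucket count never changes where a key lands.
assert(plain.getPartition("k") == bucketed.getPartition("k"))
assert(bucketed.numBuckets == 4)
```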
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 446571aa8409f..b33dc40615351 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -47,10 +47,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    */
   private def createPartitioning(
       requiredDistribution: Distribution,
-      numPartitions: Int): Partitioning = {
+      numPartitions: Int, numBuckets: Int = 0): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
-      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case ClusteredDistribution(clustering) =>
+        HashPartitioning(clustering, numPartitions, numBuckets)
       case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
       case dist => sys.error(s"Do not know how to satisfy distribution $dist")
     }
@@ -180,10 +181,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       // partitioned by the same partitioning into the same number of partitions. In that case,
       // don't try to make them match `defaultPartitions`, just use the existing partitioning.
       val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
+      val numBuckets = children.map(_.outputPartitioning match {
+        case p: OrderlessHashPartitioning => p.numBuckets
+        case _ => 0
+      }).max
       val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
         case (child, distribution) =>
           child.outputPartitioning.guarantees(
-            createPartitioning(distribution, maxChildrenNumPartitions))
+            createPartitioning(distribution, maxChildrenNumPartitions, numBuckets))
       }
 
       children = if (useExistingPartitioning) {
@@ -205,10 +210,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           // number of partitions. Otherwise, we use maxChildrenNumPartitions.
           if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
         }
-
+      val numBuckets = children.map(_.outputPartitioning match {
+        case p: OrderlessHashPartitioning => p.numBuckets
+        case _ => 0
+      }).max
       children.zip(requiredChildDistributions).map {
         case (child, distribution) =>
-          val targetPartitioning = createPartitioning(distribution, numPartitions)
+          val targetPartitioning = createPartitioning(distribution,
+            numPartitions, numBuckets)
           if (child.outputPartitioning.guarantees(targetPartitioning)) {
             child
           } else {
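Both `numBuckets` blocks in this file compute the same quantity: the largest bucket count among children whose output partitioning is an `OrderlessHashPartitioning`, with 0 contributed by everything else. A self-contained sketch with stand-in types (only `numBuckets` and the shape of the match mirror the diff; the types and `pickNumBuckets` name are scaffolding for illustration):

```scala
// Stand-in types; the real code matches on catalyst's Partitioning hierarchy
// via SparkPlan.outputPartitioning.
sealed trait Partitioning { def numPartitions: Int }
case class OrderlessHashPartitioning(numPartitions: Int, numBuckets: Int) extends Partitioning
case class UnknownPartitioning(numPartitions: Int) extends Partitioning

def pickNumBuckets(childPartitionings: Seq[Partitioning]): Int =
  childPartitionings.map {
    case p: OrderlessHashPartitioning => p.numBuckets
    case _ => 0
  }.max

// A single bucketed child is enough to thread its bucket count into createPartitioning;
// with no bucketed children the result is 0, i.e. the old behavior.
assert(pickNumBuckets(Seq(UnknownPartitioning(4), OrderlessHashPartitioning(4, 32))) == 32)
assert(pickNumBuckets(Seq(UnknownPartitioning(4))) == 0)
```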
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index afe0fbea73bd9..65dfd2d334807 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -25,8 +25,8 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -201,13 +201,7 @@ object ShuffleExchange {
       serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
-      case HashPartitioning(_, n) =>
-        new Partitioner {
-          override def numPartitions: Int = n
-          // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
-          // `HashPartitioning.partitionIdExpression` to produce partitioning key.
-          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-        }
+      case HashPartitioning(_, n, b) => new HashPartitioner(n, b)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
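The swap from the anonymous `Partitioner` to `new HashPartitioner(n, b)` preserves row placement because, as the removed comment noted, the shuffle key here is already a partition ID produced by `HashPartitioning.partitionIdExpression`: for an `Int` in `[0, n)`, `Integer.hashCode` is the identity and `x mod n == x`, so `HashPartitioner.getPartition` returns the key unchanged. A quick check of that invariant (stock `HashPartitioner`, arbitrary `n`):

```scala
import org.apache.spark.HashPartitioner

val n = 16
val part = new HashPartitioner(n)

// Every key that is already a valid partition ID maps to itself,
// matching the behavior of the removed anonymous Partitioner.
(0 until n).foreach { id =>
  assert(part.getPartition(id) == id)
}
```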