SNAP-656 Delink RDD partitions from buckets #4

Merged
merged 3 commits on Sep 1, 2016
9 changes: 8 additions & 1 deletion core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -75,11 +75,18 @@ object Partitioner {
* so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
* produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
class HashPartitioner(partitions: Int, buckets: Int = 0) extends Partitioner {
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
require(buckets >= 0, s"Number of buckets ($buckets) cannot be negative.")

def this(partitions: Int) {
this(partitions, 0)
}

def numPartitions: Int = partitions

def numBuckets: Int = buckets

def getPartition(key: Any): Int = key match {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
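The widened constructor only records the bucket count next to the partition count; key routing is unchanged. A minimal usage sketch with illustrative values (the assertions are not part of the patch):

import org.apache.spark.HashPartitioner

// Illustrative values: 4 RDD partitions backed by 16 store buckets.
val part = new HashPartitioner(4, 16)
assert(part.numPartitions == 4)
assert(part.numBuckets == 16)

// The single-argument form still works via the auxiliary constructor and
// defaults the bucket count to 0.
val plain = new HashPartitioner(4)
assert(plain.numBuckets == 0)

// Keys are still routed by hashCode modulo numPartitions, so the caveat above
// about Array keys (which hash by identity, not by contents) still applies.
val id = part.getPartition("some-key")
assert(id >= 0 && id < part.numPartitions)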
@@ -253,10 +253,10 @@ case object SinglePartition extends Partitioning {
* in the same partition. Moreover, if the same expressions are given in a different order than in
* this partitioning, the two partitionings are still considered equal.
*/
case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case class OrderlessHashPartitioning(expressions: Seq[Expression],
numPartitions: Int, numBuckets: Int)
extends Expression with Partitioning with Unevaluable {


override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
@@ -274,6 +274,7 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
}

private def anyOrderEquals(other: HashPartitioning): Boolean = {
other.numBuckets == this.numBuckets &&
other.numPartitions == this.numPartitions &&
matchExpressions(other.expressions)
}
@@ -284,7 +285,7 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
}

override def guarantees(other: Partitioning): Boolean = other match {
case o: HashPartitioning => anyOrderEquals(o)
case p: HashPartitioning => anyOrderEquals(p)
case _ => false
}

@@ -295,8 +296,8 @@ case class OrderlessHashPartitioning(expressions: Seq[Expression], numPartitions
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression with Partitioning with Unevaluable {
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int,
    numBuckets: Int = 0) extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
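A rough sketch of how the bucket-aware, order-insensitive guarantee is intended to behave, assuming matchExpressions ignores expression ordering as the class comment describes; the attributes, counts and import paths below are illustrative assumptions, not part of the patch:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, OrderlessHashPartitioning}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// A child already hash-partitioned on (a, b): 8 partitions over 8 buckets.
val existing = OrderlessHashPartitioning(Seq(a, b), 8, 8)

// A required partitioning listing the same expressions in the other order,
// with matching partition and bucket counts.
val required = HashPartitioning(Seq(b, a), 8, 8)

// anyOrderEquals now also compares numBuckets, so the expression set, the
// partition count and the bucket count must all line up.
assert(existing.guarantees(required))

// A different bucket count breaks the guarantee and forces an exchange.
assert(!existing.guarantees(HashPartitioning(Seq(b, a), 8, numBuckets = 4)))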
@@ -47,10 +47,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
*/
private def createPartitioning(
requiredDistribution: Distribution,
numPartitions: Int): Partitioning = {
numPartitions: Int, numBuckets: Int = 0): Partitioning = {
requiredDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
case ClusteredDistribution(clustering) =>
HashPartitioning(clustering, numPartitions, numBuckets)
case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
case dist => sys.error(s"Do not know how to satisfy distribution $dist")
}
@@ -180,10 +181,20 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// partitioned by the same partitioning into the same number of partitions. In that case,
// don't try to make them match `defaultPartitions`, just use the existing partitioning.
val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
val numBuckets = children.map(_.outputPartitioning match {
  case p: OrderlessHashPartitioning => p.numBuckets
  case _ => 0
}).max
val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
case (child, distribution) =>
child.outputPartitioning.guarantees(
createPartitioning(distribution, maxChildrenNumPartitions))
createPartitioning(distribution, maxChildrenNumPartitions, numBuckets))
}

children = if (useExistingPartitioning) {
@@ -205,10 +216,20 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// number of partitions. Otherwise, we use maxChildrenNumPartitions.
if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
}

val numBuckets = children.map(_.outputPartitioning match {
  case p: OrderlessHashPartitioning => p.numBuckets
  case _ => 0
}).max
children.zip(requiredChildDistributions).map {
case (child, distribution) =>
val targetPartitioning = createPartitioning(distribution, numPartitions)
val targetPartitioning = createPartitioning(distribution,
numPartitions, numBuckets)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
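The point of the two numBuckets blocks above is that a child that is already bucket-partitioned can satisfy the required distribution without a re-shuffle, because the target HashPartitioning now carries the same bucket count. A self-contained sketch with hypothetical values (imports and attribute construction are assumed as before):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()

// One child is a bucketed scan, the other an arbitrary operator.
val childPartitionings: Seq[Partitioning] = Seq(
  OrderlessHashPartitioning(Seq(key), 8, 8),
  UnknownPartitioning(2))

// Same selection as in the rule: the largest bucket count reported by any
// OrderlessHashPartitioning child, or 0 when no child is bucketed.
val numBuckets = childPartitionings.map {
  case p: OrderlessHashPartitioning => p.numBuckets
  case _ => 0
}.max

// The target created for a ClusteredDistribution on `key` carries numBuckets = 8,
// so the bucketed child can guarantee it as-is and is left untouched.
val target = HashPartitioning(Seq(key), 8, numBuckets)
assert(childPartitionings.head.guarantees(target))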
@@ -25,8 +25,8 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -201,13 +201,7 @@ object ShuffleExchange {
serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
new Partitioner {
override def numPartitions: Int = n
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
case HashPartitioning(_, n, b) => new HashPartitioner(n, b)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
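With the anonymous partition-id passthrough removed, the exchange builds a plain HashPartitioner that carries the bucket count through to the RDD layer. A reduced sketch of the selection above, limited to the two cases touched by this change (the fallback case is a placeholder, not the real code):

import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RoundRobinPartitioning}

def partitionerFor(newPartitioning: Partitioning): Partitioner = newPartitioning match {
  // Round-robin output is spread purely by partition count; buckets do not apply.
  case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
  // Hash partitioning now hands both its partition count and its bucket count to
  // the RDD-level HashPartitioner instead of wrapping a precomputed partition id.
  case HashPartitioning(_, n, b) => new HashPartitioner(n, b)
  case other => sys.error(s"Exchange not implemented for $other")
}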