
Commit

wip
sunchao committed Jan 28, 2022
1 parent 4920014 commit 0751f5c
Showing 6 changed files with 91 additions and 53 deletions.
@@ -69,15 +69,26 @@ case object AllTuples extends Distribution {
}
}

+ /**
+ * A subtype of [[Distribution]] whose tuples are clustered according to the clustering
+ * `expressions`.
+ */
+ sealed trait Clustering extends Distribution {
+ /**
+ * The expressions used to cluster the tuples in this distribution.
+ */
+ def expressions: Seq[Expression]
+ }
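A note on the new trait: it lets code that only needs the clustering keys treat both concrete distributions uniformly. A minimal sketch of the kind of dispatch this enables; the helper below is hypothetical, not part of this patch:

def clusteringExpressions(dist: Distribution): Option[Seq[Expression]] = dist match {
  // Covers both ClusteredDistribution and HashClusteredDistribution.
  case c: Clustering => Some(c.expressions)
  // e.g. AllTuples, UnspecifiedDistribution, OrderedDistribution.
  case _ => None
}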

/**
* Represents data where tuples that share the same values for the `clustering`
* [[Expression Expressions]] will be co-located in the same partition.
*/
case class ClusteredDistribution(
- clustering: Seq[Expression],
- requiredNumPartitions: Option[Int] = None) extends Distribution {
+ expressions: Seq[Expression],
+ requiredNumPartitions: Option[Int] = None) extends Distribution with Clustering {
require(
- clustering != Nil,
+ expressions != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
@@ -86,7 +97,32 @@ case class ClusteredDistribution(
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
- HashPartitioning(clustering, numPartitions)
+ HashPartitioning(expressions, numPartitions)
}
}

+ /**
+ * Represents data where tuples have been clustered according to the hash of the given
+ * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only
+ * [[HashPartitioning]] can satisfy this distribution.
+ *
+ * This is a strictly stronger guarantee than [[ClusteredDistribution]]: given a tuple and the
+ * number of partitions, this distribution fully determines which partition the tuple belongs to.
+ */
+ case class HashClusteredDistribution(
+ expressions: Seq[Expression],
+ requiredNumPartitions: Option[Int] = None) extends Distribution with Clustering {
+ require(
+ expressions != Nil,
+ "The hash expressions of a HashClusteredDistribution should not be Nil. " +
+ "An AllTuples should be used to represent a distribution that only has " +
+ "a single partition.")
+
+ override def createPartitioning(numPartitions: Int): Partitioning = {
+ assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
+ s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
+ s"the actual number of partitions is $numPartitions.")
+ HashPartitioning(expressions, numPartitions)
+ }
+ }
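For reference, the hash function named in the doc comment: in Spark's HashPartitioning, partitionIdExpression is (as of this era of the codebase) a Pmod of a Murmur3Hash over the partitioning expressions. Treat the line below as a paraphrased sketch rather than the exact source:

// Inside HashPartitioning: a row's partition id is fully determined by
// `expressions` and `numPartitions`, which is why only HashPartitioning
// can satisfy HashClusteredDistribution.
def partitionIdExpression: Expression =
  Pmod(new Murmur3Hash(expressions), Literal(numPartitions))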

@@ -157,7 +193,7 @@ trait Partitioning {
*
* @param distribution the required clustered distribution for this partitioning
*/
- def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+ def createShuffleSpec(distribution: Clustering): ShuffleSpec =
throw new IllegalStateException(s"Unexpected partitioning: ${getClass.getSimpleName}")

/**
@@ -192,7 +228,7 @@ case object SinglePartition extends Partitioning {
case _ => true
}

- override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
+ override def createShuffleSpec(clustering: Clustering): ShuffleSpec =
SinglePartitionShuffleSpec
}

@@ -211,15 +247,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
+ case h: HashClusteredDistribution =>
+ expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
+ case (l, r) => l.semanticEquals(r)
+ }
case ClusteredDistribution(requiredClustering, _) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
}
}
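The asymmetry above is the heart of the change: a hash partitioning over a subset of the cluster keys satisfies ClusteredDistribution, while HashClusteredDistribution demands a position-wise exact match. A hedged sketch (the attribute references `a` and `b` are made up for illustration):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// A subset of the cluster keys is enough for ClusteredDistribution...
assert(HashPartitioning(Seq(a), 10).satisfies(ClusteredDistribution(Seq(a, b))))
// ...but not for HashClusteredDistribution, which needs all keys, in order.
assert(!HashPartitioning(Seq(a), 10).satisfies(HashClusteredDistribution(Seq(a, b))))
assert(HashPartitioning(Seq(a, b), 10).satisfies(HashClusteredDistribution(Seq(a, b))))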

- override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
- HashShuffleSpec(this, distribution)
+ override def createShuffleSpec(clustering: Clustering): ShuffleSpec =
+ HashShuffleSpec(this, clustering)

/**
* Returns an expression that will produce a valid partition ID (i.e. non-negative and less
@@ -279,8 +319,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
}
}

- override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
- RangeShuffleSpec(this.numPartitions, distribution)
+ override def createShuffleSpec(clustering: Clustering): ShuffleSpec =
+ RangeShuffleSpec(numPartitions)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): RangePartitioning =
@@ -324,9 +364,9 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
override def satisfies0(required: Distribution): Boolean =
partitionings.exists(_.satisfies(required))

- override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
- val filtered = partitionings.filter(_.satisfies(distribution))
- ShuffleSpecCollection(filtered.map(_.createShuffleSpec(distribution)))
+ override def createShuffleSpec(clustering: Clustering): ShuffleSpec = {
+ val filtered = partitionings.filter(_.satisfies(clustering))
+ ShuffleSpecCollection(filtered.map(_.createShuffleSpec(clustering)))
}

override def toString: String = {
@@ -409,9 +449,7 @@ case object SinglePartitionShuffleSpec extends ShuffleSpec {
override def numPartitions: Int = 1
}

- case class RangeShuffleSpec(
- numPartitions: Int,
- distribution: ClusteredDistribution) extends ShuffleSpec {
+ case class RangeShuffleSpec(numPartitions: Int) extends ShuffleSpec {

// `RangePartitioning` is not compatible with any other partitioning since it can't guarantee
// data are co-partitioned for all the children, as range boundaries are randomly sampled. We
@@ -429,7 +467,7 @@ case class RangeShuffleSpec(

case class HashShuffleSpec(
partitioning: HashPartitioning,
- distribution: ClusteredDistribution) extends ShuffleSpec {
+ clustering: Clustering) extends ShuffleSpec {

/**
* A sequence where each element is a set of positions of the hash partition key to the cluster
@@ -438,7 +476,7 @@
*/
lazy val hashKeyPositions: Seq[mutable.BitSet] = {
val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
- distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+ clustering.expressions.zipWithIndex.foreach { case (distKey, distKeyPos) =>
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized, mutable.BitSet.empty))
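A worked example of hashKeyPositions, re-created over plain strings so it runs without a Spark classpath ("a" and "b" stand in for canonicalized expressions): with cluster keys (a, b, a) and hash partition keys (a, b), key a occurs at cluster positions 0 and 2 and key b at position 1.

import scala.collection.mutable

val clusterKeys = Seq("a", "b", "a") // cluster keys may contain duplicates
val hashKeys = Seq("a", "b")         // hash partitioning keys

val keyToPositions = mutable.Map.empty[String, mutable.BitSet]
clusterKeys.zipWithIndex.foreach { case (key, pos) =>
  keyToPositions.getOrElseUpdate(key, mutable.BitSet.empty).add(pos)
}
val hashKeyPositions = hashKeys.map(k => keyToPositions.getOrElse(k, mutable.BitSet.empty))
assert(hashKeyPositions == Seq(mutable.BitSet(0, 2), mutable.BitSet(1)))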
@@ -447,14 +485,14 @@
override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
case SinglePartitionShuffleSpec =>
partitioning.numPartitions == 1
- case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution) =>
+ case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherClustering) =>
// we need to check:
// 1. both distributions have the same number of clustering expressions
// 2. both partitionings have the same number of partitions
// 3. both partitionings have the same number of expressions
// 4. each pair of expression from both has overlapping positions in their
// corresponding distributions.
- distribution.clustering.length == otherDistribution.clustering.length &&
+ clustering.expressions.length == otherClustering.expressions.length &&
partitioning.numPartitions == otherPartitioning.numPartitions &&
partitioning.expressions.length == otherPartitioning.expressions.length && {
val otherHashKeyPositions = otherHashSpec.hashKeyPositions
@@ -474,8 +512,8 @@ case class HashShuffleSpec(
// will add shuffles with the default partitioning of `ClusteredDistribution`, which uses all
// the join keys.
if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
- partitioning.expressions.length == distribution.clustering.length &&
- partitioning.expressions.zip(distribution.clustering).forall {
+ partitioning.expressions.length == clustering.expressions.length &&
+ partitioning.expressions.zip(clustering.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
} else {
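To make the REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION branch above concrete, a hedged sketch of its effect, assuming the helpers visible in ShuffleSpecSuite below (withSQLConf from SQLHelper, the $"..." attribute syntax):

withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "true") {
  // Hash keys cover all cluster keys, in order: may create a partitioning.
  assert(HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
    ClusteredDistribution(Seq($"a", $"b"))).canCreatePartitioning)
  // Hash keys are a strict subset of the cluster keys: rejected under this config.
  assert(!HashShuffleSpec(HashPartitioning(Seq($"a"), 10),
    ClusteredDistribution(Seq($"a", $"b"))).canCreatePartitioning)
}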
@@ -47,7 +47,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
spec: ShuffleSpec,
dist: ClusteredDistribution,
expected: Partitioning): Unit = {
- val actual = spec.createPartitioning(dist.clustering)
+ val actual = spec.createPartitioning(dist.expressions)
if (actual != expected) {
fail(
s"""
@@ -190,12 +190,12 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
checkCompatible(
HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
ClusteredDistribution(Seq($"a", $"b"))),
- RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(10),
expected = false
)

checkCompatible(
- RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(10),
HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10),
ClusteredDistribution(Seq($"a", $"b"))),
expected = false
@@ -268,82 +268,82 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {

checkCompatible(
SinglePartitionShuffleSpec,
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(1),
expected = true
)

checkCompatible(
SinglePartitionShuffleSpec,
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
expected = true
)

checkCompatible(
- RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
- RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(10),
+ RangeShuffleSpec(10),
expected = false
)

checkCompatible(
- RangeShuffleSpec(10, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(10),
SinglePartitionShuffleSpec,
expected = false
)

checkCompatible(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(1),
SinglePartitionShuffleSpec,
expected = true
)

checkCompatible(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(1),
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
expected = true
)

checkCompatible(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))),
+ RangeShuffleSpec(1),
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))),
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"c", $"d"))))),
+ RangeShuffleSpec(1),
+ RangeShuffleSpec(1))),
expected = false
)

checkCompatible(
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
SinglePartitionShuffleSpec,
expected = true
)

checkCompatible(
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
ShuffleSpecCollection(Seq(
- SinglePartitionShuffleSpec, RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))))),
+ SinglePartitionShuffleSpec, RangeShuffleSpec(1))),
expected = true
)

checkCompatible(
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
ShuffleSpecCollection(Seq(
HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 1),
ClusteredDistribution(Seq($"a", $"b"))),
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))))),
+ RangeShuffleSpec(1))),
expected = true
)

checkCompatible(
ShuffleSpecCollection(Seq(
- RangeShuffleSpec(1, ClusteredDistribution(Seq($"a", $"b"))), SinglePartitionShuffleSpec)),
+ RangeShuffleSpec(1), SinglePartitionShuffleSpec)),
ShuffleSpecCollection(Seq(
HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 2),
ClusteredDistribution(Seq($"a", $"b"))),
- RangeShuffleSpec(2, ClusteredDistribution(Seq($"a", $"b"))))),
+ RangeShuffleSpec(2))),
expected = false
)
}
@@ -366,7 +366,7 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution)))
.canCreatePartitioning)
}
- assert(!RangeShuffleSpec(10, distribution).canCreatePartitioning)
+ assert(!RangeShuffleSpec(10).canCreatePartitioning)
}

test("createPartitioning: HashShuffleSpec") {
@@ -412,15 +412,15 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {

checkCreatePartitioning(ShuffleSpecCollection(Seq(
HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution),
- RangeShuffleSpec(10, distribution))),
+ RangeShuffleSpec(10))),
ClusteredDistribution(Seq($"c", $"d")),
HashPartitioning(Seq($"c"), 10)
)

// unsupported cases

- val msg = intercept[Exception](RangeShuffleSpec(10, distribution)
- .createPartitioning(distribution.clustering))
+ val msg = intercept[Exception](RangeShuffleSpec(10)
+ .createPartitioning(distribution.expressions))
assert(msg.getMessage.contains("Operation unsupported"))
}
}
@@ -44,7 +44,7 @@ object AQEUtils {
case p: ProjectExec =>
getRequiredDistribution(p.child).flatMap {
case h: ClusteredDistribution =>
- if (h.clustering.forall(e => p.projectList.exists(_.semanticEquals(e)))) {
+ if (h.expressions.forall(e => p.projectList.exists(_.semanticEquals(e)))) {
Some(h)
} else {
// It's possible that the user-specified repartition is effective but the output
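A brief illustration of the check above (hypothetical plans, not from the patch): the required ClusteredDistribution survives a ProjectExec only when every clustering expression still appears in the project list.

// Project [a, b] over a child requiring ClusteredDistribution(a, b) -> distribution kept
// Project [a]    over a child requiring ClusteredDistribution(a, b) -> dropped, since
//                the clustering key b is no longer available downstream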
@@ -33,8 +33,8 @@ class DataSourcePartitioning(
override def satisfies0(required: physical.Distribution): Boolean = {
super.satisfies0(required) || {
required match {
- case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
- val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+ case d: physical.ClusteredDistribution if isCandidate(d.expressions) =>
+ val attrs = d.expressions.map(_.asInstanceOf[Attribute])
partitioning.satisfy(
new ClusteredDistribution(attrs.map { a =>
val name = colNames.get(a)
@@ -146,7 +146,7 @@ case class EnsureRequirements(
} else {
val newPartitioning = bestSpecOpt.map { bestSpec =>
// Use the best spec to create a new partitioning to re-shuffle this child
- val clustering = dist.asInstanceOf[ClusteredDistribution].clustering
+ val clustering = dist.asInstanceOf[ClusteredDistribution].expressions
bestSpec.createPartitioning(clustering)
}.getOrElse {
// No best spec available, so we create default partitioning from the required
@@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec(
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)

override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
- ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
+ HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+ HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil

override def output: Seq[Attribute] = joinType match {
case _: InnerLike => left.output ++ right.output
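A hedged reading of why the streaming join needs the stricter requirement: state for partition i only holds keys that hash to i, so a restarted query must route every key to the same partition id as earlier runs, which plain ClusteredDistribution does not pin down. A minimal sketch (hypothetical key attribute, built as in the earlier example):

val key = AttributeReference("key", IntegerType)()
val stateDist = HashClusteredDistribution(Seq(key), requiredNumPartitions = Some(5))
// Only a hash partitioning on exactly these keys with the matching partition
// count satisfies the state's layout.
assert(HashPartitioning(Seq(key), 5).satisfies(stateDist))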
