[SPARK-35703][SQL][FOLLOWUP] Only eliminate shuffles if partition keys contain all the join keys #35138

Closed · wants to merge 5 commits
Changes from all commits
org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.physical
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
@@ -380,7 +381,7 @@ trait ShuffleSpec {
   /**
    * Whether this shuffle spec can be used to create partitionings for the other children.
    */
-  def canCreatePartitioning: Boolean = false
+  def canCreatePartitioning: Boolean
 
   /**
    * Creates a partitioning that can be used to re-partition the other side with the given
@@ -412,6 +413,11 @@ case class RangeShuffleSpec(
     numPartitions: Int,
     distribution: ClusteredDistribution) extends ShuffleSpec {
 
+  // `RangePartitioning` is not compatible with any other partitioning since it can't guarantee
+  // that data is co-partitioned across all the children, as range boundaries are randomly
+  // sampled. We therefore can't let `RangeShuffleSpec` create a partitioning.
+  override def canCreatePartitioning: Boolean = false
+
   override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
     case SinglePartitionShuffleSpec => numPartitions == 1
     case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith)
@@ -424,8 +430,19 @@ case class RangeShuffleSpec(
 case class HashShuffleSpec(
     partitioning: HashPartitioning,
     distribution: ClusteredDistribution) extends ShuffleSpec {
-  lazy val hashKeyPositions: Seq[mutable.BitSet] =
-    createHashKeyPositions(distribution.clustering, partitioning.expressions)
+
+  /**
+   * A sequence where each element is the set of positions of a hash partition key in the
+   * cluster keys. For instance, if the cluster keys are [a, b, b] and the hash partition keys
+   * are [a, b], the result will be [(0), (1, 2)].
+   */
+  lazy val hashKeyPositions: Seq[mutable.BitSet] = {
+    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
+    distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
+      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
+    }
+    partitioning.expressions.map(k => distKeyToPos(k.canonicalized))
+  }
 
   override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
     case SinglePartitionShuffleSpec =>
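
As a reading aid (not part of the diff): the position-mapping logic above is easy to try in isolation. The sketch below substitutes plain strings for Catalyst expressions, so string equality stands in for `Expression.canonicalized`; everything else mirrors the new `hashKeyPositions` body.

import scala.collection.mutable

object HashKeyPositionsSketch {
  // Mirrors hashKeyPositions: for each hash key, collect every position at
  // which it occurs in the cluster keys. Strings stand in for expressions,
  // and hash keys are assumed to be a subset of the cluster keys (as in the
  // real code, a missing key would throw).
  def hashKeyPositions(clustering: Seq[String], hashKeys: Seq[String]): Seq[mutable.BitSet] = {
    val keyToPos = mutable.Map.empty[String, mutable.BitSet]
    clustering.zipWithIndex.foreach { case (key, pos) =>
      keyToPos.getOrElseUpdate(key, mutable.BitSet.empty).add(pos)
    }
    hashKeys.map(keyToPos)
  }

  def main(args: Array[String]): Unit = {
    // Cluster keys [a, b, b] and hash keys [a, b] give [{0}, {1, 2}],
    // matching the example in the doc comment above.
    println(hashKeyPositions(Seq("a", "b", "b"), Seq("a", "b")))
  }
}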
@@ -451,30 +468,27 @@ case class HashShuffleSpec(
       false
   }
 
-  override def canCreatePartitioning: Boolean = true
+  override def canCreatePartitioning: Boolean = {
+    // To avoid potential data skew, we don't allow `HashShuffleSpec` to create a partitioning
+    // if the hash partition keys are not the full join keys (the cluster keys). The planner
+    // will then add shuffles with the default partitioning of `ClusteredDistribution`, which
+    // uses all the join keys.
+    if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
+      partitioning.expressions.length == distribution.clustering.length &&
+        partitioning.expressions.zip(distribution.clustering).forall {
+          case (l, r) => l.semanticEquals(r)
+        }
+    } else {
+      true
+    }
+  }
 
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
     val exprs = hashKeyPositions.map(v => clustering(v.head))
     HashPartitioning(exprs, partitioning.numPartitions)
   }
 
   override def numPartitions: Int = partitioning.numPartitions
 
-  /**
-   * Returns a sequence where each element is a set of positions of the key in `hashKeys` to its
-   * positions in `requiredClusterKeys`. For instance, if `requiredClusterKeys` is [a, b, b] and
-   * `hashKeys` is [a, b], the result will be [(0), (1, 2)].
-   */
-  private def createHashKeyPositions(
-      requiredClusterKeys: Seq[Expression],
-      hashKeys: Seq[Expression]): Seq[mutable.BitSet] = {
-    val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
-    requiredClusterKeys.zipWithIndex.foreach { case (distKey, distKeyPos) =>
-      distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
-    }
-
-    hashKeys.map(k => distKeyToPos(k.canonicalized))
-  }
 }
 
 case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
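
To make the new gate concrete, here is a minimal standalone sketch of the check `canCreatePartitioning` performs when the conf is enabled. Strings again stand in for Catalyst expressions, with `==` in place of `semanticEquals`; this is an illustration, not Spark's API.

object CoPartitionCheckSketch {
  // The hash partition keys must match the cluster keys exactly, in both
  // length and order, before this spec may re-partition the other side.
  def coversAllClusterKeys(hashKeys: Seq[String], clustering: Seq[String]): Boolean =
    hashKeys.length == clustering.length &&
      hashKeys.zip(clustering).forall { case (l, r) => l == r }

  def main(args: Array[String]): Unit = {
    assert(coversAllClusterKeys(Seq("a", "b"), Seq("a", "b")))  // full join keys: allowed
    assert(!coversAllClusterKeys(Seq("a"), Seq("a", "b")))      // subset: rejected
  }
}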

org/apache/spark/sql/internal/SQLConf.scala

@@ -396,6 +396,17 @@ object SQLConf
       .booleanConf
       .createWithDefault(true)
 
+  val REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION =
+    buildConf("spark.sql.requireAllClusterKeysForCoPartition")
+      .internal()
+      .doc("When true, the planner requires all the clustering keys as the hash partition " +
+        "keys of the children, to eliminate the shuffles for the operator that needs its " +
+        "children to be co-partitioned, such as a JOIN node. This is to avoid data skew, " +
+        "which can lead to significant performance regression if shuffles are eliminated.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort")
     .internal()
     .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
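
The conf is internal, so it won't appear in the public configuration docs. One plausible way to restore the pre-patch behavior for a single session is to set the key from the diff at session build time; the session setup below is illustrative, only the conf key is taken from the diff.

import org.apache.spark.sql.SparkSession

// Illustrative session setup; "false" reverts to the old behavior, where a
// child hash-partitioned on a subset of the join keys may be reused to
// re-partition the other side.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("co-partition-demo")
  .config("spark.sql.requireAllClusterKeysForCoPartition", "false")
  .getOrCreate()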

org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala

@@ -18,11 +18,12 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.SparkFunSuite
 /* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.internal.SQLConf
 
-class ShuffleSpecSuite extends SparkFunSuite {
+class ShuffleSpecSuite extends SparkFunSuite with SQLHelper {
   protected def checkCompatible(
       left: ShuffleSpec,
       right: ShuffleSpec,
@@ -349,12 +350,22 @@ class ShuffleSpecSuite extends SparkFunSuite {
 
   test("canCreatePartitioning") {
     val distribution = ClusteredDistribution(Seq($"a", $"b"))
-    assert(HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution).canCreatePartitioning)
+    withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") {
+      assert(HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution).canCreatePartitioning)
+    }
+    withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "true") {
+      assert(!HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution)
+        .canCreatePartitioning)
+      assert(HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution)
+        .canCreatePartitioning)
+    }
     assert(SinglePartitionShuffleSpec.canCreatePartitioning)
-    assert(ShuffleSpecCollection(Seq(
-      HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution),
-      HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution)))
-      .canCreatePartitioning)
+    withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") {
+      assert(ShuffleSpecCollection(Seq(
+        HashShuffleSpec(HashPartitioning(Seq($"a"), 10), distribution),
+        HashShuffleSpec(HashPartitioning(Seq($"a", $"b"), 10), distribution)))
+        .canCreatePartitioning)
+    }
     assert(!RangeShuffleSpec(10, distribution).canCreatePartitioning)
   }
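
Beyond the unit test, the end-to-end effect should be visible in a physical plan. A hedged sketch, reusing the hypothetical `spark` session from earlier (with the conf left at its default of "true" rather than "false"): `left` is pre-partitioned on only one of the two join keys, so the planner should shuffle both sides by all the join keys instead of re-partitioning `right` by the potentially skewed single key.

import org.apache.spark.sql.functions.col

val left = spark.range(0, 1000)
  .selectExpr("id % 10 AS a", "id % 100 AS b", "id AS v1")
  .repartition(col("a"))  // hash-partitioned on a subset of the join keys
val right = spark.range(0, 1000)
  .selectExpr("id % 10 AS a", "id % 100 AS b", "id AS v2")

// With the conf enabled, expect an Exchange with hashpartitioning(a, b, ...)
// on both sides; with it disabled, left's existing partitioning on `a` is
// kept and only `right` is shuffled, by `a` alone.
left.join(right, Seq("a", "b")).explain()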
