diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5342c8ee6d672..040f1bfab65b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -101,6 +101,8 @@ case class ClusteredDistribution( * Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the * stateful operator, only [[HashPartitioning]] (and HashPartitioning in * [[PartitioningCollection]]) can satisfy this distribution. + * When `_requiredNumPartitions` is 1, [[SinglePartition]] is essentially same as + * [[HashPartitioning]], so it can satisfy this distribution as well. * * NOTE: This is applied only to stream-stream join as of now. For other stateful operators, we * have been using ClusteredDistribution, which could construct the physical partitioning of the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 5d3f960c3bfac..e047d4c070bec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, Pmod} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.IntegerType class DistributionSuite extends SparkFunSuite { @@ -265,4 +267,80 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)), false) } + + test("Structured Streaming output partitioning and distribution") { + // Validate HashPartitioning.partitionIdExpression to be exactly expected format, because + // Structured Streaming state store requires it to be consistent across Spark versions. + val expressions = Seq($"a", $"b", $"c") + val hashPartitioning = HashPartitioning(expressions, 10) + hashPartitioning.partitionIdExpression match { + case Pmod(Murmur3Hash(es, 42), Literal(10, IntegerType), _) => + assert(es.length == expressions.length && es.zip(expressions).forall { + case (l, r) => l.semanticEquals(r) + }) + case x => fail(s"Unexpected partitionIdExpression $x for $hashPartitioning") + } + + // Validate only HashPartitioning (and HashPartitioning in PartitioningCollection) can satisfy + // StatefulOpClusteredDistribution. SinglePartition can also satisfy this distribution when + // `_requiredNumPartitions` is 1. + checkSatisfied( + HashPartitioning(Seq($"a", $"b", $"c"), 10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + true) + + checkSatisfied( + PartitioningCollection(Seq( + HashPartitioning(Seq($"a", $"b", $"c"), 10), + RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10))), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + true) + + checkSatisfied( + SinglePartition, + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1), + true) + + checkSatisfied( + PartitioningCollection(Seq( + HashPartitioning(Seq($"a", $"b"), 1), + SinglePartition)), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 1), + true) + + checkSatisfied( + HashPartitioning(Seq($"a", $"b"), 10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + HashPartitioning(Seq($"a", $"b", $"c"), 5), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + SinglePartition, + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + RoundRobinPartitioning(10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + + checkSatisfied( + UnknownPartitioning(10), + StatefulOpClusteredDistribution(Seq($"a", $"b", $"c"), 10), + false) + } }