Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-38124][SQL][SS] Introduce StatefulOpClusteredDistribution and apply to stream-stream join #35419

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,34 @@ case class ClusteredDistribution(
}
}

/**
* Represents the requirement of distribution on the stateful operator.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we need to put "structured streaming" before "stateful operator"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "stateful" is already representing the streaming context, but no big deal if we repeat here.

*
* Each partition in stateful operator initializes state store(s), which are independent with state
* store(s) in other partitions. Since it is not possible to repartition the data in state store,
* Spark should make sure the physical partitioning of the stateful operator is unchanged across
* Spark versions. Violation of this requirement may bring silent correctness issue.
*
* Since this distribution relies on [[HashPartitioning]] on the physical partitioning of the
* stateful operator, only [[HashPartitioning]] can satisfy this distribution.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to also explain briefly this only applies to StreamingSymmetricHashJoinExec now, and the challenge to apply it to other stateful operators? Maybe we can also file a JIRA for other stateful operators and leave a TODO here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where to leave a comment is the issue. It is unlikely that we often look at StatefulOpClusteredDistribution - probably giving more chance to get attention if we put a comment on every stateful operators wherever using ClusteredDistribution. Totally redundant, but gives a sign of warning whenever they try to change it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah either works for me. The comment is also non-blocking for this PR, as this is an improvement for documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only [[HashPartitioning]] can satisfy this distribution. -> only [[HashPartitioning]], and [[PartitioningCollection]] of [[HashPartitioning]] can satisfy this distribution. ?

*/
case class StatefulOpClusteredDistribution(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the new name :) thanks for making it more specific.

Do we also need to update HashShuffleSpec so that two HashPartitionings can be compatible with each other when checking against StatefulOpClusteredDistributions? this is the previous behavior where Spark would avoid shuffle if both sides of the streaming join are co-partitioned.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to update HashShuffleSpec so that two HashPartitionings can be compatible with each other when checking against StatefulOpClusteredDistributions? this is the previous behavior where Spark would avoid shuffle if both sides of the streaming join are co-partitioned.

Each input must follow the required distribution provided from stateful operator to respect the requirement of state partitioning. State partitioning is the first class, so even both sides of the streaming join are co-partitioned, Spark must perform shuffle if they don't match with state partitioning. (If that was the previous behavior, we broke something at some time point.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think this PR will skip shuffle if both sides of a streaming join are co-partitioned. In EnsureRequirements, we currently mainly do two things:

  1. check if output partitioning can satisfy the required distribution
  2. if there are two children, check if they are compatible with each other, and insert shuffle if not.

In the step 2) we'd only consider ClusteredDistribution at the moment, so in case of StatefulOpClusteredDistributions this step is simply skipped. Consequently, Spark will skip shuffle even if only step 1) is successful.

State partitioning is the first class, so even both sides of the streaming join are co-partitioned, Spark must perform shuffle if they don't match with state partitioning.

I'm not quite sure about this. Shouldn't we retain the behavior before #32875? Quoting the comment from @cloud-fan:

I think this is kind of a potential bug. Let's say that we have 2 tables that can report hash partitioning optionally (e.g. controlled by a flag). Assume a streaming query is first run with the flag off, which means the tables do not report hash partitioning, then Spark will add shuffles before the stream-stream join, and the join state (steaming checkpoint) is partitioned by Spark's murmur3 hash function. Then we restart the streaming query with the flag on, and the 2 tables report hash partitioning (not the same as Spark's murmur3). Spark will not add shuffles before stream-stream join this time, and leads to wrong result, because the left/right join child is not co-partitioned with the join state in the previous run.

If we respect co-partitioning and avoid shuffle before #32875 but start shuffle after this PR, I think similar issue like described in the comment can happen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. check if output partitioning can satisfy the required distribution

For stream-stream join, once each input satisfy the required "hash" distribution of each, they will be co-partitioned. stream-stream join must guarantee this.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem brought up because ClusteredDistribution has much more relaxed requirement; what we really need to require for "any" stateful operator including stream-stream join is that for all children a specific tuple having specific grouping key must be bound to the deterministic partition "ID", which only HashClusteredDistribution could guarantee.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one behavior difference between this PR and the state before #32875 is that, previously, we'd also check spark.sql.shuffle.partitions and insert shuffle if there's not enough parallelism from the input. However, this PR doesn't do that since it skips the step 2) above.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HashClusteredDistribution also has a requirement of the number of partitions, so step 1) should fulfill it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, StatefulOpClusteredDistributions is very strict and requires numPartitions as well. I don't think we need extra co-partition check for it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it should be good then!

expressions: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends Distribution {
require(
expressions != Nil,
"The expressions for hash of a StatefulOpClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

override def createPartitioning(numPartitions: Int): Partitioning = {
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is there any chance we specify empty requiredNumPartitions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In sense of "defensive programming", it shouldn't. I just didn't change the implementation of HashClusteredPartition, but now I think it worths to do.

s"This StatefulOpClusteredDistribution requires ${requiredNumPartitions.get} " +
s"partitions, but the actual number of partitions is $numPartitions.")
HashPartitioning(expressions, numPartitions)
}
}

/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. Its requirement is defined as the following:
Expand Down Expand Up @@ -200,6 +228,11 @@ case object SinglePartition extends Partitioning {
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*
* Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
* stateful operators to retain the same physical partitioning during the lifetime of the query
* (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
* across Spark versions. Violation of this requirement may bring silent correctness issue.
Comment on lines +237 to +238
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we enforce this assumption in unit test as well? e.g. in StreamingJoinSuite. It's great to highlight in comment here, but people always forget and the unit test will fail loudly when we introduce any invalid change.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a test for verifying this, although it is not exhaustive.

test("streaming join should require HashClusteredDistribution from children") {
val input1 = MemoryStream[Int]
val input2 = MemoryStream[Int]
val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b)
val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b)
val joined = df1.join(df2, Seq("a", "b")).select('a)
testStream(joined)(
AddData(input1, 1.to(1000): _*),
AddData(input2, 1.to(1000): _*),
CheckAnswer(1.to(1000): _*),
Execute { query =>
// Verify the query plan
def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
expressions.flatMap {
case ref: AttributeReference => Some(ref.name)
}
}
val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
assert(query.lastExecution.executedPlan.collect {
case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _,
ShuffleExchangeExec(opA: HashPartitioning, _, _),
ShuffleExchangeExec(opB: HashPartitioning, _, _))
if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
&& partitionExpressionsColumns(opB.expressions) === Seq("a", "b")
&& opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j
}.size == 1)
})
}

If we want to be exhaustive, I can make a combination of repartitions which could have not triggered shuffle with hash partitioning against joining keys if stream-stream join uses ClusteredDistribution. It may not be exhaustive for future-proof indeed.

Instead, if we are pretty sure StateOpClusteredDistribution would work as expected, we can simply check the required child distribution of the physical plan of stream-stream join, and additionally check the output partitioning of each child to be HashPartitioning with joining keys (this effectively verifies StateOpClusteredDistribution indeed).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh actually I was referring to the assumption:

  • HashPartitioning.partitionIdExpression has to be exactly Pmod(new Murmur3Hash(expressions), Literal(numPartitions)).

It would be just to add some logic to check opA/opB.partitionIdExpression for the opA/opB at Line 598/599. I can also do it later if it's not clear to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We checked HashPartitioning and partitionExpression here - the remaining is partitionIdExpression, which is the implementation of HashPartitioning.

That said, it would be nice if we have a separate test against HashPartitioning if we don't have one. Could you please check and craft one if we don't have it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I can add one later this week.

*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
extends Expression with Partitioning with Unevaluable {
Expand All @@ -211,6 +244,10 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case h: StatefulOpClusteredDistribution =>
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case ClusteredDistribution(requiredClustering, _) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
Expand All @@ -46,6 +47,7 @@ object AggUtils {
}

private def createAggregate(
requiredChildDistributionOption: Option[Seq[Distribution]] = None,
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
groupingExpressions: Seq[NamedExpression] = Nil,
aggregateExpressions: Seq[AggregateExpression] = Nil,
Expand All @@ -59,6 +61,7 @@ object AggUtils {

if (useHash && !forceSortAggregate) {
HashAggregateExec(
requiredChildDistributionOption = requiredChildDistributionOption,
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
Expand All @@ -72,6 +75,7 @@ object AggUtils {

if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
ObjectHashAggregateExec(
requiredChildDistributionOption = requiredChildDistributionOption,
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
Expand All @@ -81,6 +85,7 @@ object AggUtils {
child = child)
} else {
SortAggregateExec(
requiredChildDistributionOption = requiredChildDistributionOption,
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
Expand Down Expand Up @@ -299,12 +304,16 @@ object AggUtils {
child = child)
}

// This is used temporarily to pick up the required child distribution for the stateful
// operator.
val tempRestored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
partialAggregate)

val partialMerged1: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes),
requiredChildDistributionOption = Some(tempRestored.requiredChildDistribution),
groupingExpressions = groupingAttributes,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
Expand All @@ -321,8 +330,7 @@ object AggUtils {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes),
requiredChildDistributionOption = Some(restored.requiredChildDistribution),
groupingExpressions = groupingAttributes,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
Expand All @@ -349,7 +357,7 @@ object AggUtils {
val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)

createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
requiredChildDistributionOption = Some(restored.requiredChildDistribution),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions,
aggregateAttributes = finalAggregateAttributes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtil
*/
trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
def requiredChildDistributionExpressions: Option[Seq[Expression]]
def requiredChildDistributionOption: Option[Seq[Distribution]]
def groupingExpressions: Seq[NamedExpression]
def aggregateExpressions: Seq[AggregateExpression]
def aggregateAttributes: Seq[Attribute]
Expand Down Expand Up @@ -90,10 +91,14 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
requiredChildDistributionOption match {
case Some(dist) => dist.toList
case _ =>
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
}

Expand All @@ -102,7 +107,8 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
*/
def toSortAggregate: SortAggregateExec = {
SortAggregateExec(
requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
requiredChildDistributionOption, requiredChildDistributionExpressions, groupingExpressions,
aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions,
child)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
Expand All @@ -44,6 +45,7 @@ import org.apache.spark.util.Utils
* Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size.
*/
case class HashAggregateExec(
requiredChildDistributionOption: Option[Seq[Distribution]],
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
Expand Down Expand Up @@ -58,6 +59,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* }}}
*/
case class ObjectHashAggregateExec(
requiredChildDistributionOption: Option[Seq[Distribution]],
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
Expand All @@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
* Sort-based aggregate operator.
*/
case class SortAggregateExec(
requiredChildDistributionOption: Option[Seq[Distribution]],
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, StatefulOpClusteredDistribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state._
Expand Down Expand Up @@ -93,8 +93,8 @@ case class FlatMapGroupsWithStateExec(
* to have the same grouping so that the data are co-lacated on the same task.
*/
override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
StatefulOpClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
StatefulOpClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
Nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ case class StreamingSymmetricHashJoinExec(
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
StatefulOpClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
Copy link
Member

@viirya viirya Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is other ClusteredDistribution usages in statefulOperators, e.g. ClusteredDistribution, do we need to update them too? As they are also stateful operators, they also need strict partition requirement?

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer the long comment thread #35419 (comment)

We have to fix them, but we should have a plan to avoid introducing "silently broken" on existing queries. We need more time to think through how to address the "already broken" thing. They seem to be broken from their introduction (Spark 2.2+), so it could be possible someone is even leveraging the relaxed requirement as a "feature", despite it would be very risky if they tried to adjust partitioning by theirselves. Even for this case we can't simply break their query.

I'll create a new JIRA ticket, and/or initiate discussion thread in dev@ regarding this. I need some time to build a plan (with options) to address this safely.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it could be possible someone is even leveraging the relaxed requirement as a "feature"

Suppose they have a stream of event logs having (userId, time, blabla), and do time-window aggregation like this:

df
  .withWatermark("time", "1 hour")
  .groupBy("userId", window("time", "10 minutes"))
  .agg(count("*"))

groupBy won't trigger shuffle for various output partitionings of df, since streaming aggregation requires ClusteredDistribution. The thing is, it could be from the intention to 1) reduce shuffle in any way, or 2) try to control the partitioning to deal with skew. (I can't easily think of skew from applying hash function against "grouping keys + time window", but once they see it, they will try to fix it. ...Technically saying, they must not try to fix it as state partitioning will be no longer the same with operator's partitioning...)

Both are very risky (as of now, changing the partitioning during query lifetime would lead to correctness issue), but it's still from users' intention and they already did it anyway so we can't simply enforce the partitioning and silently break this again.

Furthermore, we seem to allow data source to produce output partitioning by itself, which can satisfy ClusteredDistribution. This is still very risky for stateful operator's perspective, but once the output partitioning is guaranteed to be not changed, it's still a great change to reduce (unnecessary) shuffle.
(Just saying hypothetically; stateful operator has to require specific output partitioning once the state is built, so it's unlikely that we can leverage the partitioning of data source. We may find a way later but not now.)


override def output: Seq[Attribute] = joinType match {
case _: InnerLike => left.output ++ right.output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, StatefulOpClusteredDistribution}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
Expand Down Expand Up @@ -337,7 +337,7 @@ case class StateStoreRestoreExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if this change could introduce extra shuffle for streaming aggregate. Previously the operator requires ClusteredDistribution, and right now it requires StatefulOpClusteredDistribution/HashClusteredDistribution.

ClusteredDistribution is more relaxed than HashClusteredDistribution in the sense that a HashPartitioning(c1) can satisfy ClusteredDistribution(c1, c2), but cannot satisfy HashClusteredDistribution(c1, c2). In short, ClusteredDistribution allows child to be hash-partitioned on subset of required keys. So for aggregate, if the plan is already shuffled on subset of group-by columns, Spark will not add a shuffle again before group-by.

For example:

MemoryStream[(Int, Int)].toDF()
  .repartition($"_1")
  .groupBy($"_1", $"_2")
  .agg(count("*"))
  .as[(Int, Int, Long)]

and the query plan:

WriteToDataSourceV2 org.apache.spark.sql.execution.streaming.sources.MicroBatchWrite@5940f7c2, org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$1940/1200613952@4861dac3
+- *(4) HashAggregate(keys=[_1#588, _2#589], functions=[count(1)], output=[_1#588, _2#589, count(1)#596L])
   +- StateStoreSave [_1#588, _2#589], state info [ checkpoint = file:/private/var/folders/y5/hnsw8mz93vs57ngcd30y6y9c0000gn/T/streaming.metadata-0d7cb004-92dd-4b0d-9d90-5a65c0d2934c/state, runId = 68598bd1-cf35-4bf7-a167-5f73dc9f4d84, opId = 0, ver = 0, numPartitions = 5], Complete, 0, 1
      +- *(3) HashAggregate(keys=[_1#588, _2#589], functions=[merge_count(1)], output=[_1#588, _2#589, count#663L])
         +- StateStoreRestore [_1#588, _2#589], state info [ checkpoint = file:/private/var/folders/y5/hnsw8mz93vs57ngcd30y6y9c0000gn/T/streaming.metadata-0d7cb004-92dd-4b0d-9d90-5a65c0d2934c/state, runId = 68598bd1-cf35-4bf7-a167-5f73dc9f4d84, opId = 0, ver = 0, numPartitions = 5], 1
            +- *(2) HashAggregate(keys=[_1#588, _2#589], functions=[merge_count(1)], output=[_1#588, _2#589, count#663L])
               +- *(2) HashAggregate(keys=[_1#588, _2#589], functions=[partial_count(1)], output=[_1#588, _2#589, count#663L])
                  +- Exchange hashpartitioning(_1#588, 5), REPARTITION_BY_COL, [id=#2008]
                     +- *(1) Project [_1#588, _2#589]
                        +- MicroBatchScan[_1#588, _2#589] MemoryStreamDataSource

One can argue the previous behavior for streaming aggregate is not wrong. As long as all rows for same keys are colocated in same partition, StateStoreRestore/Store should output correct answer for streaming aggregate. If we make the change here, I assume one extra shuffle on ($"_1", $"_2") would be introduced, and it might yield incorrect result when running the new query plan against the existing state store?

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't define "repartition just before stateful operator" as "unsupported operation across query lifetime", no?

The thing is that once the query is run, the partitioning of stateful operator must not be changed during lifetime. Since we don't store the information of partitioning against stateful operator in the checkpoint, we have no way around other than enforcing the partitioning of stateful operator as the "one" what we basically expect.

As I said in #32875, there is a room for improvement, but the effort on improvement must be performed after we fix this issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't define "repartition just before stateful operator" as "unsupported operation across query lifetime", no?

I ran the query in StreamingAggregationSuite.scala and it seems fine. I briefly checked UnsupportedOperationChecker.scala and didn't find we disallow "repartition just before stateful operator".

The thing is that once the query is run, the partitioning of stateful operator must not be changed during lifetime. Since we don't store the information of partitioning against stateful operator in the checkpoint, we have no way around other than enforcing the partitioning of stateful operator as the "one" what we basically expect.

I agree with you @HeartSaVioR. I want to raise a concern here it might change the query plan for certain streaming aggregate query (as above synthetic query), and it could break existing state store when running with next Spark 3.3 code, based on my limited understanding.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make the change here, I assume one extra shuffle on ($"_1", $"_2") would be introduced, and it might yield incorrect result when running the new query plan against the existing state store?

Unfortunately yes. We may need to craft some tools to analyze the state and repartition if the partitioning is already messed up. But leaving this as it is would bring more chances to let users' state be indeterministic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we don't define "repartition just before stateful operator" as "unsupported operation across query lifetime", no?

I ran the query in StreamingAggregationSuite.scala and it seems fine. I briefly checked UnsupportedOperationChecker.scala and didn't find we disallow "repartition just before stateful operator".

That is the problem we have. We didn't disallow the case where it brings silent correctness issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is the problem we have. We didn't disallow the case where it brings silent correctness issue.

This is just a synthetic example I composed to verify my theory. But I think it might break for more cases, such as GROUP BY c1, c2 after JOIN ON c1. I am trying to say it would break for the queries which are partitioned on subset of group-by keys.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even nowadays multiple stateful operators don't work properly due to the global watermark, so we don't need to worry about the partitioning between stateful operators. We just need to worry about the partitioning between upstream (in most cases non stateful) and the stateful operator.

}
}

Expand Down Expand Up @@ -496,7 +496,7 @@ case class StateStoreSaveExec(
if (keyExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
}

Expand Down Expand Up @@ -573,7 +573,8 @@ case class SessionWindowStateStoreRestoreExec(
}

override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(keyWithoutSessionExpressions,
stateInfo.map(_.numPartitions)) :: Nil
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
Expand Down Expand Up @@ -684,7 +685,7 @@ case class SessionWindowStateStoreSaveExec(
override def outputPartitioning: Partitioning = child.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}

override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
Expand Down Expand Up @@ -742,7 +743,7 @@ case class StreamingDeduplicateExec(

/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(
executedPlan.find {
case WholeStageCodegenExec(
HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true
HashAggregateExec(_, _, _, _, _, _, _, _: LocalTableScanExec)) => true
case _ => false
}.isDefined,
"LocalTableScanExec should be within a WholeStageCodegen domain.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
CheckNewAnswer((5, 10, 5, 15, 5, 25)))
}

test("streaming join should require HashClusteredDistribution from children") {
test("streaming join should require StatefulOpClusteredDistribution from children") {
val input1 = MemoryStream[Int]
val input2 = MemoryStream[Int]

Expand Down