[SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery

### What changes were proposed in this pull request?

This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle the case where an `Aggregate` node contains an aggregate expression in the left-hand operand of an IN-subquery expression. The explicit handler moves the IN-subquery expressions out of the `Aggregate` and into a parent `Project` node. The `Aggregate` continues to perform the aggregations that were used as operands of the IN-subquery expression, but no longer includes the IN-subquery expression itself. After the IN-subquery expressions have been pulled up into the `Project` node, `RewritePredicateSubquery#apply` is called again to handle the `Project` as a `UnaryNode`. The `Join` is now inserted between the `Project` and the `Aggregate` node, and the join condition uses an attribute rather than an aggregate expression, e.g.:
```
Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
+- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
   :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
   :  +- LocalRelation [col1#32, col2#33]
   +- LocalRelation [c2#39L]
```
Despite how it looks, `sum(col2)#41L` in the above join condition is an attribute name, not an aggregate expression.
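
A minimal sketch (not part of the patch) that builds a plan of this shape with the Catalyst test DSL, the same DSL the new `RewriteSubquerySuite` test uses:
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Relations standing in for the temp views v1 and v2 below.
val v1 = LocalRelation($"c1".int, $"c2".int)
val v2 = LocalRelation($"col1".int, $"col2".int)

// SELECT col1, sum(col2) IN (SELECT c2 FROM v1) FROM v2 GROUP BY col1
val plan = v2.groupBy($"col1")($"col1", sum($"col2").in(ListQuery(v1.select($"c2"))))
// After analysis and optimization, RewritePredicateSubquery leaves sum(col2) in the
// Aggregate and joins on its output attribute, producing the plan shown above.
```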

### Why are the changes needed?

The following query fails:
```
create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);
create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);

select col1, sum(col2) in (select c2 from v1)
from v2 group by col1;
```
It fails with this error:
```
[INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, false]) SQLSTATE: XX000
```
With `SPARK_TESTING=1`, it fails with this error:
```
[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch RewriteSubquery generated an invalid plan: Special expressions are placed in the wrong plan:
Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
+- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
   :- LocalRelation [col1#11, col2#12]
   +- LocalRelation [c2#18L]
```
The issue is that `RewritePredicateSubquery` builds a `Join` operator where the join condition contains an aggregate expression.

The bug is in the handler for `UnaryNode` in `RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and assumes that the left-hand operand of the IN-subquery can be used in the join condition. This works fine in most cases, but not when the left-hand operand is an aggregate expression.

This PR moves the offending IN-subqueries to a `Project` node, with the aggregates replaced by attributes referring to the aggregate expressions. The resulting join condition now uses those attributes rather than the actual aggregate expressions.
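
Schematically, the transformation looks like this (a sketch, with `g` standing for the grouping expressions, `subq` for the subquery, and `_aggregateexpression` the alias name the handler introduces):
```
Aggregate [g], [..., sum(e) IN (subq), ...]
+- child
```
becomes
```
Project [..., _aggregateexpression IN (subq), ...]
+- Aggregate [g], [g, sum(e) AS _aggregateexpression]
   +- child
```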

### Does this PR introduce _any_ user-facing change?

No, other than allowing this type of query to succeed.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48627 from bersprockets/aggregate_in_set_issue.

Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit e02ff1c)
Signed-off-by: Wenchen Fan <[email protected]>
bersprockets authored and cloud-fan committed Jan 23, 2025
1 parent 696a541 commit eb128b0
Showing 3 changed files with 168 additions and 40 deletions.
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
    }
  }

  /**
   * Returns true if any expression in `exprs` contains an IN-subquery whose left-hand
   * operand includes an aggregate expression.
   */
  def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
    exprs.exists { expr =>
      exprContainsAggregateInSubquery(expr)
    }
  }

  /**
   * Returns true if `expr` contains an IN-subquery whose left-hand operand includes
   * an aggregate expression.
   */
  def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
    expr.exists {
      case InSubquery(values, _) =>
        values.exists { v =>
          v.exists {
            case _: AggregateExpression => true
            case _ => false
          }
        }
      case _ => false
    }
  }
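
  // Illustrative example (not part of the patch): for
  //   SELECT SUM(col2) IN (SELECT c3 FROM v1) FROM v2
  // the Aggregate's expressions contain an InSubquery whose left-hand values are
  //   Seq(sum(col2#18))
  // and sum(col2#18) is an AggregateExpression, so exprContainsAggregateInSubquery
  // returns true. For `col1 IN (SELECT c3 FROM v1)` the left-hand value is a plain
  // attribute reference, so it returns false.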


  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
    case Filter(condition, child)
@@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
      }
    }

    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
    //
    // If an Aggregate node contains such an IN-subquery, this handler will pull up all
    // expressions from the Aggregate node into a new Project node. The new Project node
    // will then be handled by the UnaryNode handler.
    //
    // The UnaryNode handler uses the left-hand side of the IN-subquery in a
    // join condition. Thus, without this pre-transformation, the join condition
    // would contain an aggregate, which is illegal. With this pre-transformation, the
    // join condition contains an attribute from the left-hand side of the
    // IN-subquery contained in the Project node.
    //
    // For example:
    //
    // SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
    // FROM v2;
    //
    // The above query has this plan on entry to RewritePredicateSubquery#apply:
    //
    // Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
    // : +- LocalRelation [c3#28L]
    // +- LocalRelation [col2#18, col3#19]
    //
    // Note that the Aggregate node contains the IN-subquery and the left-hand
    // side of the IN-subquery is an aggregate expression (sum(col2#18)).
    //
    // This handler transforms the above plan into the following:
    // scalastyle:off line.size.limit
    //
    // Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
    // : +- LocalRelation [c3#28L]
    // +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
    //    +- LocalRelation [col2#18, col3#19]
    //
    // scalastyle:on line.size.limit
    // Note that both the IN-subquery and the greater-than expressions have been
    // pulled up into the Project node. These expressions use attributes
    // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
    // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
    case p @ PhysicalAggregation(
        groupingExpressions, aggregateExpressions, resultExpressions, child)
        if exprsContainsAggregateInSubquery(p.expressions) =>
      val aggExprs = aggregateExpressions.map(
        ae => Alias(ae, "_aggregateexpression")(ae.resultId))
      val aggExprIds = aggExprs.map(_.exprId).toSet
      val resExprs = resultExpressions.map(_.transform {
        case a: AttributeReference if aggExprIds.contains(a.exprId) =>
          a.withName("_aggregateexpression")
      }.asInstanceOf[NamedExpression])
      // Rewrite the projection and the aggregate separately and then piece them together.
      val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
      val newProj = Project(resExprs, newAgg)
      handleUnaryNode(newProj)

    case u: UnaryNode if u.expressions.exists(
        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
  }

  /**
   * Handle the unary node case
   */
  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
    var newChild = u.child
    var introducedAttrs = Seq.empty[Attribute]
    val updatedNode = u.mapExpressions(expr => {
      val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
      newChild = p
      introducedAttrs ++= newAttrs
      // The newExpr cannot be None
      newExpr.get
    }).withNewChildren(Seq(newChild))
    updatedNode match {
      case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
        // If we have introduced new `exists`-attributes that are referenced by
        // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
        // the first() aggregate function. first() is Spark's executable version of the
        // any_value() aggregate function.
        // We do this to keep the aggregation valid, i.e. avoid references outside of
        // aggregate functions that are not in grouping expressions.
        // Note that the same `exists` attr will never appear in groupingExpressions due to
        // the PullOutGroupingExpressions rule.
        // Also note: the value of `exists` is functionally determined by grouping
        // expressions, so applying any aggregate function is semantically safe.
        val aggFunctionReferences = a.aggregateExpressions.
          flatMap(extractAggregateExpressions).
          flatMap(_.references).toSet
        val nonAggFuncReferences =
          a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
        val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)

        // Replace all eligible `exists` attributes by First(exists) among aggregateExpressions.
        val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
          aggExpr.transformUp {
            case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
              new First(attr).toAggregateExpression()
          }.asInstanceOf[NamedExpression]
        }
        a.copy(aggregateExpressions = newAggregateExpressions)
      case _ => updatedNode
    }
  }
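
  // Illustrative example (not part of the patch): with WRAP_EXISTS_IN_AGGREGATE_FUNCTION
  // enabled, an aggregate list such as
  //   [col1#11, exists#20 AS (sum(col2) IN (listquery()))#19]
  // is rewritten so that the bare exists#20 reference becomes first(exists#20, false),
  // matching the Aggregate shown in the invalid plan in the PR description above.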

@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.LongType


class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
    Optimize.executeAndTrack(query.analyze, tracker)
    assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
  }

test("SPARK-50091: Don't put aggregate expression in join condition") {
val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
val optimized = Optimize.execute(plan.analyze)
val aggregate = relation2
.select($"col2")
.groupBy()(sum($"col2").as("_aggregateexpression"))
val correctAnswer = aggregate
.join(relation1.select(Cast($"c3", LongType).as("c3")),
ExistenceJoin($"exists".boolean.withNullability(false)),
Some($"_aggregateexpression" === $"c3"))
.select($"exists".as("(sum(col2) IN (listquery()))")).analyze
comparePlans(optimized, correctAnswer)
}
}
30 changes: 30 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
      checkAnswer(df3, Row(7))
    }
  }

test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
withView("v1", "v2") {
Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
.toDF("c1", "c2", "c3")
.createOrReplaceTempView("v1")
Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
.toDF("col1", "col2", "col3")
.createOrReplaceTempView("v2")

val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
checkAnswer(df1,
Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)

val df2 = sql("""SELECT
| col1,
| SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
|FROM v2 GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df2,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)

val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
|FROM v2
|GROUP BY col1
|ORDER BY col1""".stripMargin)
checkAnswer(df3,
Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
}
}
}
