[SPARK-43838][SQL] Fix subquery on single table with having clause can't be optimized #41347

Closed · wants to merge 9 commits
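For orientation, the query shape this PR addresses is a scalar subquery over a single table combined with a HAVING clause. The query below is copied from the new golden-file test added further down in this diff:

-- SPARK-43838: subquery on a single table with a HAVING clause
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1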
@@ -978,10 +978,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsBase
case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
expr.failAnalysis(
errorClass = "INVALID_SUBQUERY_EXPRESSION." +
"SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
messageParameters = Map("number" -> query.output.size.toString))
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
expr.origin)
}

if (outerAttrs.nonEmpty) {
@@ -82,6 +82,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

private def existDuplicatedExprId(
existingRelations: mutable.HashSet[RelationWrapper],
plan: RelationWrapper): Boolean = {
existingRelations.filter(_.cls == plan.cls)
.exists(_.outputAttrIds.intersect(plan.outputAttrIds).nonEmpty)
}

/**
* Deduplicate any duplicated relations of a LogicalPlan
* @param existingRelations the known unique relations for a LogicalPlan
@@ -95,59 +102,161 @@
case p: LogicalPlan if p.isStreaming => (plan, false)

case m: MultiInstanceRelation =>
val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id))
if (existingRelations.contains(planWrapper)) {
val newNode = m.newInstance()
newNode.copyTagsFrom(m)
(newNode, true)
} else {
existingRelations.add(planWrapper)
(m, false)
}
deduplicateAndRenew[LogicalPlan with MultiInstanceRelation](
existingRelations,
m,
_.output.map(_.exprId.id),
node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation])

case p: Project =>
deduplicateAndRenew[Project](
existingRelations,
p,
newProject => findAliases(newProject.projectList).map(_.exprId.id).toSeq,
newProject => newProject.copy(newAliases(newProject.projectList)))

case s: SerializeFromObject =>
deduplicateAndRenew[SerializeFromObject](
existingRelations,
s,
_.serializer.map(_.exprId.id),
newSer => newSer.copy(newSer.serializer.map(_.newInstance())))

case f: FlatMapGroupsInPandas =>
deduplicateAndRenew[FlatMapGroupsInPandas](
existingRelations,
f,
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case f: FlatMapCoGroupsInPandas =>
deduplicateAndRenew[FlatMapCoGroupsInPandas](
existingRelations,
f,
_.output.map(_.exprId.id),
newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))

case m: MapInPandas =>
deduplicateAndRenew[MapInPandas](
existingRelations,
m,
_.output.map(_.exprId.id),
newMap => newMap.copy(output = newMap.output.map(_.newInstance())))

case p: PythonMapInArrow =>
deduplicateAndRenew[PythonMapInArrow](
existingRelations,
p,
_.output.map(_.exprId.id),
newMap => newMap.copy(output = newMap.output.map(_.newInstance())))

case a: AttachDistributedSequence =>
deduplicateAndRenew[AttachDistributedSequence](
existingRelations,
a,
_.producedAttributes.map(_.exprId.id).toSeq,
newAttach => newAttach.copy(sequenceAttr = newAttach.producedAttributes
.map(_.newInstance()).head))

case g: Generate =>
deduplicateAndRenew[Generate](
existingRelations,
g,
_.generatorOutput.map(_.exprId.id), newGenerate =>
newGenerate.copy(generatorOutput = newGenerate.generatorOutput.map(_.newInstance())))

case e: Expand =>
deduplicateAndRenew[Expand](
existingRelations,
e,
_.producedAttributes.map(_.exprId.id).toSeq,
newExpand => newExpand.copy(output = newExpand.output.map(_.newInstance())))

case w: Window =>
deduplicateAndRenew[Window](
existingRelations,
w,
_.windowExpressions.map(_.exprId.id),
newWindow => newWindow.copy(windowExpressions =
newWindow.windowExpressions.map(_.newInstance())))

case s: ScriptTransformation =>
deduplicateAndRenew[ScriptTransformation](
existingRelations,
s,
_.output.map(_.exprId.id),
newScript => newScript.copy(output = newScript.output.map(_.newInstance())))

case plan: LogicalPlan =>
var planChanged = false
val newPlan = if (plan.children.nonEmpty) {
val newChildren = mutable.ArrayBuffer.empty[LogicalPlan]
for (c <- plan.children) {
val (renewed, changed) = renewDuplicatedRelations(existingRelations, c)
newChildren += renewed
if (changed) {
planChanged = true
}
}
deduplicate(existingRelations, plan)
}

val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
private def deduplicate(
existingRelations: mutable.HashSet[RelationWrapper],
plan: LogicalPlan): (LogicalPlan, Boolean) = {
var planChanged = false
val newPlan = if (plan.children.nonEmpty) {
val newChildren = mutable.ArrayBuffer.empty[LogicalPlan]
for (c <- plan.children) {
val (renewed, changed) = renewDuplicatedRelations(existingRelations, c)
newChildren += renewed
if (changed) {
planChanged = true
}
}

val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
}

if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(
plan
.children
.flatMap(_.output).zip(newChildren.flatMap(_.output))
.filter { case (a1, a2) => a1.exprId != a2.exprId }
)
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
}
if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(plan.children.flatMap(_.output)
.zip(newChildren.flatMap(_.output)).filter { case (a1, a2) => a1.exprId != a2.exprId })
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewSubquery.withNewChildren(newChildren.toSeq)
planWithNewChildren.rewriteAttrs(attrMap)
}
} else {
plan
planWithNewSubquery.withNewChildren(newChildren.toSeq)
}
} else {
plan
}
} else {
plan
}
(newPlan, planChanged)
}

private def deduplicateAndRenew[T <: LogicalPlan](
existingRelations: mutable.HashSet[RelationWrapper], plan: T,
getExprIds: T => Seq[Long],
copyNewPlan: T => T): (LogicalPlan, Boolean) = {
var (newPlan, planChanged) = deduplicate(existingRelations, plan)
if (newPlan.resolved) {
val exprIds = getExprIds(newPlan.asInstanceOf[T])
if (exprIds.nonEmpty) {
val planWrapper = RelationWrapper(newPlan.getClass, exprIds)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newPlan = copyNewPlan(newPlan.asInstanceOf[T])
newPlan.copyTagsFrom(plan)
(newPlan, true)
} else {
existingRelations.add(planWrapper)
(newPlan, planChanged)
}
} else {
(newPlan, planChanged)
}
} else {
(newPlan, planChanged)
}
}

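For context, DeduplicateRelations renews expression IDs when the same relation occurs more than once in a plan, which is what the helpers above check for and repair. A minimal illustrative query of that situation (not taken from this PR, table name is hypothetical):

-- Illustrative only: t1 appears twice, so the analyzer must assign the second
-- occurrence fresh expression IDs so that attributes from the two sides can
-- be told apart.
SELECT * FROM t1 a JOIN t1 b ON a.c1 = b.c1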
@@ -271,7 +271,10 @@ case class ScalarSubquery(
mayHaveCountBug: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
if (!plan.schema.fields.nonEmpty) {
Member Author:

Make sure the AnalysisException will be thrown, not an AssertionError.

Contributor:

Can we really reach this code branch?

Member Author (@Hisoka-X, Jul 13, 2023):

Yes. Usually this error is thrown by checkAnalysis, but we may call dataType in DeduplicateRelations, which causes this exception to be thrown. This change ensures that the thrown exception is consistent.

Change before:

Caused by: sbt.ForkMain$ForkError: java.lang.AssertionError: assertion failed: Scalar subquery should have only one column
	at scala.Predef$.assert(Predef.scala:223)
	at org.apache.spark.sql.catalyst.expressions.ScalarSubquery.dataType(subquery.scala:274)
	at org.apache.spark.sql.catalyst.expressions.Alias.toAttribute(namedExpressions.scala:194)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$$anonfun$findAliases$1.applyOrElse(DeduplicateRelations.scala:530)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$$anonfun$findAliases$1.applyOrElse(DeduplicateRelations.scala:530)
	at scala.PartialFunction.$anonfun$runWith$1$adapted(PartialFunction.scala:145)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.collect(TraversableLike.scala:407)
	at scala.collection.TraversableLike.collect$(TraversableLike.scala:405)
	at scala.collection.AbstractTraversable.collect(Traversable.scala:108)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.findAliases(DeduplicateRelations.scala:530)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.org$apache$spark$sql$catalyst$analysis$DeduplicateRelations$$renewDuplicatedRelations(DeduplicateRelations.scala:120)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.apply(DeduplicateRelations.scala:40)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.apply(DeduplicateRelations.scala:38)

throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length,
origin)
}
plan.schema.fields.head.dataType
}
override def nullable: Boolean = true
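
For reference, an illustrative query (not from this PR) whose scalar subquery returns two output columns. checkAnalysis rejects it with SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN, and the dataType change above is meant to raise the same AnalysisException, rather than an AssertionError, when the condition is hit outside checkAnalysis, per the author's comment above:

-- Illustrative only: the scalar subquery returns two columns (c1, c2).
SELECT c1, (SELECT c1, c2 FROM t1 t2 WHERE t1.c1 = t2.c1) FROM t1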
@@ -2000,6 +2000,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
messageParameters = Map("function" -> funcStr))
}

def subqueryReturnMoreThanOneColumn(number: Int, origin: Origin): Throwable = {
new AnalysisException(
errorClass = "INVALID_SUBQUERY_EXPRESSION." +
"SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
origin = origin,
messageParameters = Map("number" -> number.toString))
}

def unsupportedCorrelatedReferenceDataTypeError(
expr: Expression,
dataType: DataType,
@@ -437,25 +437,6 @@ class LeftSemiAntiJoinPushDownSuite extends PlanTest {
}
}

Seq(LeftSemi, LeftAnti).foreach { case jt =>
Member Author:

This test is unnecessary, because we can deduplicate those attributes when the anti-join / semi-join is a self-join. Please refer to #39131.

test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
val aggregation = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")($"id", sum($"c").as("sum"))

// reference "b" exists in left leg, and the children of the right leg of the join
val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum")
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum"))
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
.analyze
comparePlans(optimized, correctAnswer)
}
}

Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
@@ -1977,17 +1977,16 @@ Union false, false
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Project [c1#x AS c1#x, c2#x AS c2#x, c1#x AS c1#x, c2#x AS c2#x]
+- Project [c1#x, c2#x, c1#x, c2#x]
+- Join Inner
:- SubqueryAlias spark_catalog.default.t1
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t4
+- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
+- Project [c1#x, c2#x, c1#x, c2#x]
+- Join Inner
:- SubqueryAlias spark_catalog.default.t1
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t4
+- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
@@ -2030,27 +2029,26 @@ Union false, false
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Project [c1#x AS c1#x, c2#x AS c2#x, c2#x AS c2#x]
+- Project [c1#x, c2#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Union false, false
: :- Project [c2#x]
: : +- Filter (outer(c1#x) <= c1#x)
: : +- SubqueryAlias spark_catalog.default.t1
: : +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- Project [c2#x]
: +- Filter (c1#x < outer(c1#x))
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t2
+- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
+- Project [c1#x, c2#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Union false, false
: :- Project [c2#x]
: : +- Filter (outer(c1#x) <= c1#x)
: : +- SubqueryAlias spark_catalog.default.t1
: : +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- Project [c2#x]
: +- Filter (c1#x < outer(c1#x))
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t2
+- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
@@ -1057,3 +1057,21 @@ Project [c1#xL, c2#xL]
: +- Range (1, 2, step=1, splits=None)
+- SubqueryAlias t1
+- Range (1, 3, step=1, splits=None)


-- !query
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
-- !query analysis
Project [c1#x, c2#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL]
: +- Filter (cnt#xL = cast(0 as bigint))
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(c1#x) = c1#x)
: +- SubqueryAlias t2
: +- SubqueryAlias t1
: +- View (`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
@@ -255,3 +255,6 @@ select * from (
where t1.id = t2.id ) c2
from range (1, 3) t1 ) t
where t.c2 is not null;

-- SPARK-43838: Subquery on single table with having clause
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
@@ -598,3 +598,12 @@ where t.c2 is not null
struct<c1:bigint,c2:bigint>
-- !query output
1 1


-- !query
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
-- !query schema
struct<c1:int,c2:int,scalarsubquery(c1):bigint>
-- !query output
0 1 NULL
1 2 NULL