Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-43838][SQL] Fix subquery on single table with having clause can't be optimized #41347

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -978,10 +978,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case ScalarSubquery(query, outerAttrs, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
expr.failAnalysis(
errorClass = "INVALID_SUBQUERY_EXPRESSION." +
"SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
messageParameters = Map("number" -> query.output.size.toString))
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
expr.origin)
}

if (outerAttrs.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

/**
 * Returns true when a previously-recorded relation of the same plan class shares at
 * least one expression id with `plan`'s output attribute ids, i.e. adding `plan`
 * would introduce duplicated expression ids.
 */
private def existDuplicatedExprId(
    existingRelations: mutable.HashSet[RelationWrapper],
    plan: RelationWrapper): Boolean = {
  // Single pass: test class equality and id overlap together rather than
  // materializing an intermediate filtered set first.
  existingRelations.exists { existing =>
    existing.cls == plan.cls &&
      existing.outputAttrIds.intersect(plan.outputAttrIds).nonEmpty
  }
}

/**
* Deduplicate any duplicated relations of a LogicalPlan
* @param existingRelations the known unique relations for a LogicalPlan
Expand All @@ -105,6 +112,193 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
(m, false)
}

case p @ Project(_, child) if p.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
p, changed)
var newProject = newPlan.asInstanceOf[Project]
val aliaesAttr = findAliases(newProject.projectList)
if (aliaesAttr.nonEmpty) {
val planWrapper = RelationWrapper(p.getClass, aliaesAttr.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newProject = newProject.copy(newProject.projectList.map {
case a: Alias => a.newInstance()
case other => other
})
newProject.copyTagsFrom(p)
(newProject, true)
} else {
existingRelations.add(planWrapper)
(newProject, planChanged)
}
} else {
(newProject, planChanged)
}

case s @ SerializeFromObject(_, child) if s.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newSer = newPlan.asInstanceOf[SerializeFromObject]
val planWrapper = RelationWrapper(newSer.getClass, newSer.serializer.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newSer = newSer.copy(newSer.serializer.map(_.newInstance()))
newSer.copyTagsFrom(s)
(newSer, true)
} else {
existingRelations.add(planWrapper)
(newSer, planChanged)
}

case f @ FlatMapGroupsInPandas(_, _, _, child) if f.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newFlatMap = newPlan.asInstanceOf[FlatMapGroupsInPandas]
val planWrapper = RelationWrapper(newFlatMap.getClass,
newFlatMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newFlatMap = newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))
newFlatMap.copyTagsFrom(plan)
(newFlatMap, true)
} else {
existingRelations.add(planWrapper)
(newFlatMap, planChanged)
}

case f @ FlatMapCoGroupsInPandas(_, _, _, _, left, right) if f.resolved =>
val (leftRenew, leftChanged) = renewDuplicatedRelations(existingRelations, left)
val (rightRenew, rightChanged) = renewDuplicatedRelations(existingRelations, right)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(leftRenew,
rightRenew), plan, leftChanged || rightChanged)

var newFlatMap = newPlan.asInstanceOf[FlatMapCoGroupsInPandas]
val planWrapper = RelationWrapper(newFlatMap.getClass,
newFlatMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newFlatMap = newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))
newFlatMap.copyTagsFrom(plan)
(newFlatMap, true)
} else {
existingRelations.add(planWrapper)
(newFlatMap, planChanged)
}

case m @ MapInPandas(_, _, child, _) if m.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newMap = newPlan.asInstanceOf[MapInPandas]
val planWrapper = RelationWrapper(newMap.getClass,
newMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newMap = newMap.copy(output = newMap.output.map(_.newInstance()))
newMap.copyTagsFrom(plan)
(newMap, true)
} else {
existingRelations.add(planWrapper)
(newMap, planChanged)
}

case p @ PythonMapInArrow(_, _, child, _) if p.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newMap = newPlan.asInstanceOf[PythonMapInArrow]
val planWrapper = RelationWrapper(newMap.getClass,
newMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newMap = newMap.copy(output = newMap.output.map(_.newInstance()))
newMap.copyTagsFrom(plan)
(newMap, true)
} else {
existingRelations.add(planWrapper)
(newMap, planChanged)
}

case a @ AttachDistributedSequence(_, child) if a.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newAttach = newPlan.asInstanceOf[AttachDistributedSequence]
val planWrapper = RelationWrapper(newAttach.getClass,
newAttach.producedAttributes.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newAttach = newAttach.copy(sequenceAttr = newAttach.producedAttributes.map(_.newInstance())
.head)
newAttach.copyTagsFrom(plan)
(newAttach, true)
} else {
existingRelations.add(planWrapper)
(newAttach, planChanged)
}

case g @ Generate(_, _, _, _, _, child) if g.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newGenerate = newPlan.asInstanceOf[Generate]
val planWrapper = RelationWrapper(newGenerate.getClass,
newGenerate.generatorOutput.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newGenerate = newGenerate.copy(generatorOutput = newGenerate.generatorOutput.map(
_.newInstance()))
newGenerate.copyTagsFrom(plan)
(newGenerate, true)
} else {
existingRelations.add(planWrapper)
(newGenerate, planChanged)
}

case e @ Expand(_, _, child) if e.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newExpand = newPlan.asInstanceOf[Expand]
val planWrapper = RelationWrapper(newExpand.getClass,
newExpand.producedAttributes.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newExpand = newExpand.copy(output = newExpand.output.map(_.newInstance()))
newExpand.copyTagsFrom(plan)
(newExpand, true)
} else {
existingRelations.add(planWrapper)
(newExpand, planChanged)
}

case w @ Window(_, _, _, child) if w.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newWindow = newPlan.asInstanceOf[Window]
val planWrapper = RelationWrapper(newWindow.getClass,
newWindow.windowExpressions.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newWindow = newWindow.copy(windowExpressions =
newWindow.windowExpressions.map(_.newInstance()))
newWindow.copyTagsFrom(plan)
(newWindow, true)
} else {
existingRelations.add(planWrapper)
(newWindow, planChanged)
}

case s @ ScriptTransformation(_, _, child, _) if s.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newScript = newPlan.asInstanceOf[ScriptTransformation]
val planWrapper = RelationWrapper(newScript.getClass,
newScript.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newScript = newScript.copy(output = newScript.output.map(_.newInstance()))
newScript.copyTagsFrom(plan)
(newScript, true)
} else {
existingRelations.add(planWrapper)
(newScript, planChanged)
}

case plan: LogicalPlan =>
var planChanged = false
val newPlan = if (plan.children.nonEmpty) {
Expand All @@ -117,37 +311,49 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
}
val (newPlan, changed) = getNewPlanWithNewChildren(existingRelations, newChildren.toArray,
plan, planChanged)
planChanged |= changed
newPlan
} else {
plan
}
(newPlan, planChanged)
}

if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(
plan
.children
.flatMap(_.output).zip(newChildren.flatMap(_.output))
.filter { case (a1, a2) => a1.exprId != a2.exprId }
)
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
}
} else {
planWithNewSubquery.withNewChildren(newChildren.toSeq)
}
} else {
private def getNewPlanWithNewChildren(
existingRelations: mutable.HashSet[RelationWrapper],
newChildren: Array[LogicalPlan],
plan: LogicalPlan, changed: Boolean): (LogicalPlan, Boolean) = {
var planChanged = changed
val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
}

val newPlan = if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(
plan
.children
.flatMap(_.output).zip(newChildren.flatMap(_.output))
.filter { case (a1, a2) => a1.exprId != a2.exprId }
)
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
}
} else {
plan
planWithNewSubquery.withNewChildren(newChildren.toSeq)
}
(newPlan, planChanged)
} else {
plan
}
(newPlan, planChanged)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ case class ScalarSubquery(
mayHaveCountBug: Option[Boolean] = None)
extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable {
override def dataType: DataType = {
assert(plan.schema.fields.nonEmpty, "Scalar subquery should have only one column")
if (!plan.schema.fields.nonEmpty) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure the AnalysisException will be thrown, not an AssertionError

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we really reach this code branch?

Copy link
Member Author

@Hisoka-X Hisoka-X Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Usually this error is thrown by checkAnalysis, but we may call dataType in DeduplicateRelations, which causes this exception to be thrown. This change ensures that the thrown exception is consistent.

Change before:

Caused by: sbt.ForkMain$ForkError: java.lang.AssertionError: assertion failed: Scalar subquery should have only one column
	at scala.Predef$.assert(Predef.scala:223)
	at org.apache.spark.sql.catalyst.expressions.ScalarSubquery.dataType(subquery.scala:274)
	at org.apache.spark.sql.catalyst.expressions.Alias.toAttribute(namedExpressions.scala:194)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$$anonfun$findAliases$1.applyOrElse(DeduplicateRelations.scala:530)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$$anonfun$findAliases$1.applyOrElse(DeduplicateRelations.scala:530)
	at scala.PartialFunction.$anonfun$runWith$1$adapted(PartialFunction.scala:145)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.collect(TraversableLike.scala:407)
	at scala.collection.TraversableLike.collect$(TraversableLike.scala:405)
	at scala.collection.AbstractTraversable.collect(Traversable.scala:108)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.findAliases(DeduplicateRelations.scala:530)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.org$apache$spark$sql$catalyst$analysis$DeduplicateRelations$$renewDuplicatedRelations(DeduplicateRelations.scala:120)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.apply(DeduplicateRelations.scala:40)
	at org.apache.spark.sql.catalyst.analysis.DeduplicateRelations$.apply(DeduplicateRelations.scala:38)

throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(plan.schema.fields.length,
origin)
}
plan.schema.fields.head.dataType
}
override def nullable: Boolean = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
messageParameters = Map("function" -> funcStr))
}

/**
 * Error for a scalar subquery whose plan produces more than one output column.
 *
 * @param number the actual number of output columns produced by the subquery
 * @param origin the tree position used to report where the error occurred
 * @return an [[AnalysisException]] with the
 *         `INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN`
 *         error class
 */
def subqueryReturnMoreThanOneColumn(number: Int, origin: Origin): Throwable = {
  val params = Map("number" -> number.toString)
  new AnalysisException(
    errorClass =
      "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
    origin = origin,
    messageParameters = params)
}

def unsupportedCorrelatedReferenceDataTypeError(
expr: Expression,
dataType: DataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,25 +437,6 @@ class LeftSemiAntiJoinPushDownSuite extends PlanTest {
}
}

Seq(LeftSemi, LeftAnti).foreach { case jt =>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is unnecessary, because we can deduplicate those attributes when the anti-join / semi-join is a self-join. Please refer to #39131.

test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
val aggregation = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")($"id", sum($"c").as("sum"))

// reference "b" exists in left leg, and the children of the right leg of the join
val originalQuery = aggregation.select(($"id" + 1).as("id_plus_1"), $"sum")
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select($"b".as("id"), $"c")
.groupBy($"id")(($"id" + 1).as("id_plus_1"), sum($"c").as("sum"))
.join(aggregation, joinType = jt, condition = Some($"id" === $"id_plus_1"))
.analyze
comparePlans(optimized, correctAnswer)
}
}

Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
Expand Down
Loading