Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Hisoka-X committed Jul 10, 2023
1 parent 56b9f6c commit 0c12ef9
Show file tree
Hide file tree
Showing 9 changed files with 426 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

/**
 * Checks whether `plan` shares at least one output attribute id with an already
 * registered relation of the same class.
 *
 * @param existingRelations the relation wrappers collected so far
 * @param plan the wrapper of the plan node being checked
 * @return true if some wrapper in `existingRelations` has the same `cls` as `plan` and
 *         at least one id in common with `plan.outputAttrIds`; false otherwise
 *         (including when `plan.outputAttrIds` is empty)
 */
private def existDuplicatedExprId(
    existingRelations: mutable.HashSet[RelationWrapper],
    plan: RelationWrapper): Boolean = {
  // Build the lookup set once instead of re-intersecting against every candidate.
  val planAttrIds = plan.outputAttrIds.toSet
  // A single short-circuiting pass: no intermediate filtered HashSet and no
  // materialized intersection, the first shared id ends the search.
  existingRelations.exists { existing =>
    existing.cls == plan.cls && existing.outputAttrIds.exists(planAttrIds.contains)
  }
}

/**
* Deduplicate any duplicated relations of a LogicalPlan
* @param existingRelations the known unique relations for a LogicalPlan
Expand All @@ -105,6 +112,193 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
(m, false)
}

case p @ Project(_, child) if p.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
p, changed)
var newProject = newPlan.asInstanceOf[Project]
val aliaesAttr = findAliases(newProject.projectList)
if (aliaesAttr.nonEmpty) {
val planWrapper = RelationWrapper(p.getClass, aliaesAttr.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newProject = newProject.copy(newProject.projectList.map {
case a: Alias => a.newInstance()
case other => other
})
newProject.copyTagsFrom(p)
(newProject, true)
} else {
existingRelations.add(planWrapper)
(newProject, planChanged)
}
} else {
(newProject, planChanged)
}

case s @ SerializeFromObject(_, child) if s.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newSer = newPlan.asInstanceOf[SerializeFromObject]
val planWrapper = RelationWrapper(newSer.getClass, newSer.serializer.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newSer = newSer.copy(newSer.serializer.map(_.newInstance()))
newSer.copyTagsFrom(s)
(newSer, true)
} else {
existingRelations.add(planWrapper)
(newSer, planChanged)
}

case f @ FlatMapGroupsInPandas(_, _, _, child) if f.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newFlatMap = newPlan.asInstanceOf[FlatMapGroupsInPandas]
val planWrapper = RelationWrapper(newFlatMap.getClass,
newFlatMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newFlatMap = newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))
newFlatMap.copyTagsFrom(plan)
(newFlatMap, true)
} else {
existingRelations.add(planWrapper)
(newFlatMap, planChanged)
}

case f @ FlatMapCoGroupsInPandas(_, _, _, _, left, right) if f.resolved =>
val (leftRenew, leftChanged) = renewDuplicatedRelations(existingRelations, left)
val (rightRenew, rightChanged) = renewDuplicatedRelations(existingRelations, right)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(leftRenew,
rightRenew), plan, leftChanged || rightChanged)

var newFlatMap = newPlan.asInstanceOf[FlatMapCoGroupsInPandas]
val planWrapper = RelationWrapper(newFlatMap.getClass,
newFlatMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newFlatMap = newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))
newFlatMap.copyTagsFrom(plan)
(newFlatMap, true)
} else {
existingRelations.add(planWrapper)
(newFlatMap, planChanged)
}

case m @ MapInPandas(_, _, child, _) if m.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newMap = newPlan.asInstanceOf[MapInPandas]
val planWrapper = RelationWrapper(newMap.getClass,
newMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newMap = newMap.copy(output = newMap.output.map(_.newInstance()))
newMap.copyTagsFrom(plan)
(newMap, true)
} else {
existingRelations.add(planWrapper)
(newMap, planChanged)
}

case p @ PythonMapInArrow(_, _, child, _) if p.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newMap = newPlan.asInstanceOf[PythonMapInArrow]
val planWrapper = RelationWrapper(newMap.getClass,
newMap.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newMap = newMap.copy(output = newMap.output.map(_.newInstance()))
newMap.copyTagsFrom(plan)
(newMap, true)
} else {
existingRelations.add(planWrapper)
(newMap, planChanged)
}

case a @ AttachDistributedSequence(_, child) if a.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newAttach = newPlan.asInstanceOf[AttachDistributedSequence]
val planWrapper = RelationWrapper(newAttach.getClass,
newAttach.producedAttributes.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newAttach = newAttach.copy(sequenceAttr = newAttach.producedAttributes.map(_.newInstance())
.head)
newAttach.copyTagsFrom(plan)
(newAttach, true)
} else {
existingRelations.add(planWrapper)
(newAttach, planChanged)
}

case g @ Generate(_, _, _, _, _, child) if g.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newGenerate = newPlan.asInstanceOf[Generate]
val planWrapper = RelationWrapper(newGenerate.getClass,
newGenerate.generatorOutput.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newGenerate = newGenerate.copy(generatorOutput = newGenerate.generatorOutput.map(
_.newInstance()))
newGenerate.copyTagsFrom(plan)
(newGenerate, true)
} else {
existingRelations.add(planWrapper)
(newGenerate, planChanged)
}

case e @ Expand(_, _, child) if e.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newExpand = newPlan.asInstanceOf[Expand]
val planWrapper = RelationWrapper(newExpand.getClass,
newExpand.producedAttributes.map(_.exprId.id).toSeq)
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newExpand = newExpand.copy(output = newExpand.output.map(_.newInstance()))
newExpand.copyTagsFrom(plan)
(newExpand, true)
} else {
existingRelations.add(planWrapper)
(newExpand, planChanged)
}

case w @ Window(_, _, _, child) if w.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newWindow = newPlan.asInstanceOf[Window]
val planWrapper = RelationWrapper(newWindow.getClass,
newWindow.windowExpressions.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newWindow = newWindow.copy(windowExpressions =
newWindow.windowExpressions.map(_.newInstance()))
newWindow.copyTagsFrom(plan)
(newWindow, true)
} else {
existingRelations.add(planWrapper)
(newWindow, planChanged)
}

case s @ ScriptTransformation(_, _, child, _) if s.resolved =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, child)
val (newPlan, planChanged) = getNewPlanWithNewChildren(existingRelations, Array(renewed),
plan, changed)
var newScript = newPlan.asInstanceOf[ScriptTransformation]
val planWrapper = RelationWrapper(newScript.getClass,
newScript.output.map(_.exprId.id))
if (existDuplicatedExprId(existingRelations, planWrapper)) {
newScript = newScript.copy(output = newScript.output.map(_.newInstance()))
newScript.copyTagsFrom(plan)
(newScript, true)
} else {
existingRelations.add(planWrapper)
(newScript, planChanged)
}

case plan: LogicalPlan =>
var planChanged = false
val newPlan = if (plan.children.nonEmpty) {
Expand All @@ -117,37 +311,49 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
}
}

val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
}
val (newPlan, changed) = getNewPlanWithNewChildren(existingRelations, newChildren.toArray,
plan, planChanged)
planChanged |= changed
newPlan
} else {
plan
}
(newPlan, planChanged)
}

if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(
plan
.children
.flatMap(_.output).zip(newChildren.flatMap(_.output))
.filter { case (a1, a2) => a1.exprId != a2.exprId }
)
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
}
} else {
planWithNewSubquery.withNewChildren(newChildren.toSeq)
}
} else {
private def getNewPlanWithNewChildren(
existingRelations: mutable.HashSet[RelationWrapper],
newChildren: Array[LogicalPlan],
plan: LogicalPlan, changed: Boolean): (LogicalPlan, Boolean) = {
var planChanged = changed
val planWithNewSubquery = plan.transformExpressions {
case subquery: SubqueryExpression =>
val (renewed, changed) = renewDuplicatedRelations(existingRelations, subquery.plan)
if (changed) planChanged = true
subquery.withNewPlan(renewed)
}

val newPlan = if (planChanged) {
if (planWithNewSubquery.childrenResolved) {
val planWithNewChildren = planWithNewSubquery.withNewChildren(newChildren.toSeq)
val attrMap = AttributeMap(
plan
.children
.flatMap(_.output).zip(newChildren.flatMap(_.output))
.filter { case (a1, a2) => a1.exprId != a2.exprId }
)
if (attrMap.isEmpty) {
planWithNewChildren
} else {
planWithNewChildren.rewriteAttrs(attrMap)
}
} else {
plan
planWithNewSubquery.withNewChildren(newChildren.toSeq)
}
(newPlan, planChanged)
} else {
plan
}
(newPlan, planChanged)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1977,17 +1977,16 @@ Union false, false
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Project [c1#x AS c1#x, c2#x AS c2#x, c1#x AS c1#x, c2#x AS c2#x]
+- Project [c1#x, c2#x, c1#x, c2#x]
+- Join Inner
:- SubqueryAlias spark_catalog.default.t1
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t4
+- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
+- Project [c1#x, c2#x, c1#x, c2#x]
+- Join Inner
:- SubqueryAlias spark_catalog.default.t1
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t4
+- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
Expand Down Expand Up @@ -2030,27 +2029,26 @@ Union false, false
: +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- Project [c1#x AS c1#x, c2#x AS c2#x, c2#x AS c2#x]
+- Project [c1#x, c2#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Union false, false
: :- Project [c2#x]
: : +- Filter (outer(c1#x) <= c1#x)
: : +- SubqueryAlias spark_catalog.default.t1
: : +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- Project [c2#x]
: +- Filter (c1#x < outer(c1#x))
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t2
+- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
+- Project [c1#x, c2#x, c2#x]
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
: +- SubqueryAlias __auto_generated_subquery_name
: +- Union false, false
: :- Project [c2#x]
: : +- Filter (outer(c1#x) <= c1#x)
: : +- SubqueryAlias spark_catalog.default.t1
: : +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- Project [c2#x]
: +- Filter (c1#x < outer(c1#x))
: +- SubqueryAlias spark_catalog.default.t4
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t2
+- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1057,3 +1057,21 @@ Project [c1#xL, c2#xL]
: +- Range (1, 2, step=1, splits=None)
+- SubqueryAlias t1
+- Range (1, 3, step=1, splits=None)


-- !query
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
-- !query analysis
Project [c1#x, c2#x, scalar-subquery#x [c1#x] AS scalarsubquery(c1)#xL]
: +- Filter (cnt#xL = cast(0 as bigint))
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(c1#x) = c1#x)
: +- SubqueryAlias t2
: +- SubqueryAlias t1
: +- View (`t1`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+- LocalRelation [col1#x, col2#x]
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,6 @@ select * from (
where t1.id = t2.id ) c2
from range (1, 3) t1 ) t
where t.c2 is not null;

-- SPARK-43838: Subquery on single table with having clause
SELECT c1, c2, (SELECT count(*) cnt FROM t1 t2 WHERE t1.c1 = t2.c1 HAVING cnt = 0) FROM t1
Loading

0 comments on commit 0c12ef9

Please sign in to comment.