
[SPARK-32945][SQL] Avoid collapsing projects if reaching max allowed common exprs #29950

Closed
wants to merge 12 commits
@@ -732,10 +732,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
* `GlobalLimit(LocalLimit)` pattern is also considered.
*/
object CollapseProject extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject

if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
getLargestNumOfCommonOutput(p1.projectList, p2.projectList) > maxCommonExprs) {
Member: indentation?

p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -766,6 +768,23 @@ object CollapseProject extends Rule[LogicalPlan] {
})
}
Member Author: We could extend this to other cases like case p @ Project(_, agg: Aggregate), but leave it untouched for now.
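For reference, a rough sketch of what that hypothetical extension might look like inside the same pattern match, reusing the helpers added in this PR (not part of this change):

    // Hypothetical: apply the same guard when a Project sits on top of an Aggregate,
    // whose aggregate expressions can also alias expensive expressions.
    case p @ Project(_, agg: Aggregate) =>
      val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) ||
          getLargestNumOfCommonOutput(p.projectList, agg.aggregateExpressions) > maxCommonExprs) {
        p
      } else {
        agg.copy(aggregateExpressions =
          buildCleanedProjectList(p.projectList, agg.aggregateExpressions))
      }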


// Returns the largest number of times any common output from the lower operator is used
// in the upper operator's project list.
private def getLargestNumOfCommonOutput(
Member: Can't we share the code between this and moreThanMaxAllowedCommonOutput?

Member Author: The two places look similar, but the parameters are slightly different. We could make them share the same code, but it is only a few lines and the refactoring would need more changes, so it doesn't seem worth it to me.

upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Int = {
val aliases = collectAliases(lower)
val exprMap = mutable.HashMap.empty[Attribute, Int]

upper.foreach(_.collect {
case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
})

if (exprMap.size > 0) {
exprMap.maxBy(_._2)._2
} else {
0
}
}

private def haveCommonNonDeterministicOutput(
upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
// Create a map of Aliases to their values from the lower projection.
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.planning

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
@@ -108,6 +110,8 @@ object ScanOperation extends OperationHelper with PredicateHelper {
type ScanReturnType = Option[(Option[Seq[NamedExpression]],
Seq[Expression], LogicalPlan, AttributeMap[Expression])]

val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject

def unapply(plan: LogicalPlan): Option[ReturnType] = {
collectProjectsAndFilters(plan) match {
case Some((fields, filters, child, _)) =>
@@ -124,14 +128,34 @@ object ScanOperation extends OperationHelper with PredicateHelper {
}.exists(!_.deterministic))
}

def moreThanMaxAllowedCommonOutput(
expr: Seq[NamedExpression],
Member (@dongjoon-hyun, Oct 9, 2020): indentation? It seems that there is one more space here.

aliases: AttributeMap[Expression]): Boolean = {
val exprMap = mutable.HashMap.empty[Attribute, Int]

expr.foreach(_.collect {
case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
})

val commonOutputs = if (exprMap.size > 0) {
exprMap.maxBy(_._2)._2
} else {
0
}

commonOutputs > maxCommonExprs
}

private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
plan match {
case Project(fields, child) =>
collectProjectsAndFilters(child) match {
case Some((_, filters, other, aliases)) =>
// Follow CollapseProject and only keep going if the collected Projects
// do not have common non-deterministic expressions.
if (!hasCommonNonDeterministic(fields, aliases)) {
// do not have common non-deterministic expressions, or do not use a common output
// more than the maximum allowed number of times.
if (!hasCommonNonDeterministic(fields, aliases)
|| !moreThanMaxAllowedCommonOutput(fields, aliases)) {
Member: nit, you may want to move || into line 157.

Member Author: Sure, thanks.

val substitutedFields =
fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
@@ -1926,6 +1926,19 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT =
buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject")
Member: If we set this value to 1, can all the existing tests still pass?

Member Author: I guess not. We are likely to have at least a few common expressions in a collapsed projection. If set to 1, no duplicated expression is allowed at all.
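For illustration, a minimal sketch in the catalyst DSL used by the test suite further down (the relation and names here are made up, not part of the PR): with the limit at 1, the lower alias 'a being referenced twice would already block collapsing, even though duplicating 'b + 1 is cheap.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // 'a is used twice in the upper Project, so a limit of 1 would keep
    // these two Projects from being collapsed.
    val relation = LocalRelation('b.int)
    val plan = relation
      .select(('b + 1).as("a"))
      .select(('a + 1).as("x"), ('a * 2).as("y"))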

.doc("An integer number indicates the maximum allowed number of a common expression " +
Member: a common expression -> common input attributes?

Member: nit: "the maximum allowed number of a common expression can be collapsed into upper Project from lower Project ..." => "the maximum allowed number of common input attributes when collapsing adjacent Projects ..."?

Member Author: Not exactly the same, but I revised the doc.

"can be collapsed into upper Project from lower Project by optimizer rule " +
"`CollapseProject`. Normally `CollapseProject` will collapse adjacent Project " +
Member (@maropu, Oct 21, 2020): (Just a comment) Even if we set spark.sql.optimizer.excludedRules to CollapseProject, it seems like Spark still respects this value in ScanOperation? That behaviour might be okay, but it looks a bit weird to me.

Member Author: Yeah, but currently even if we exclude CollapseProject, ScanOperation will still collapse projections. Maybe update this doc?

Member: Hm, I see. Yeah, updating the doc sounds good to me.

"and merge expressions. But in some edge cases, expensive expressions might be " +
"duplicated many times in merged Project by this optimization. This config sets " +
"a maximum number. Once an expression is duplicated more than this number " +
"if merging two Project, Spark SQL will skip the merging.")
.version("3.1.0")
.intConf
.createWithDefault(20)
Member: Just a question. Is there a reason to choose 20?

Member Author: No, I just picked a number that seems bad enough for repeating an expression.

Member: If possible, can we introduce this configuration with Int.MaxValue in 3.1.0 first? We can reduce it later.

Member Author: Sure. It is safer.

Member: +1
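For reference, if the config lands as proposed it can be tuned like any other SQL config; the value 3 below is only an example:

    // Override the proposed limit for the current session (example value).
    spark.conf.set("spark.sql.optimizer.maxCommonExprsInCollapseProject", 3)
    // SQL equivalent:
    //   SET spark.sql.optimizer.maxCommonExprsInCollapseProject=3;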


val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
buildConf("spark.sql.decimalOperations.allowPrecisionLoss")
.internal()
@@ -3289,6 +3302,8 @@ class SQLConf extends Serializable with Logging {

def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)

def maxCommonExprsInCollapseProject: Int = getConf(MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT)

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{MetadataBuilder, StructType}

class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -170,4 +171,34 @@ class CollapseProjectSuite extends PlanTest {
val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
comparePlans(optimized, expected)
}

test("SPARK-32945: avoid collapsing projects if reaching max allowed common exprs") {
val options = Map.empty[String, String]
val schema = StructType.fromDDL("a int, b int, c string, d long")

Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
// If we collapse two Projects, `JsonToStructs` will be repeated three times.
val relation = LocalRelation('json.string)
val query = relation.select(
JsonToStructs(schema, options, 'json).as("struct"))
Member (@dongjoon-hyun, Nov 13, 2020): indentation? Maybe the following is better?

- val query1 = relation.select(
-   JsonToStructs(schema, options, 'json).as("struct"))
-   .select(
+ val query1 = relation.select(JsonToStructs(schema, options, 'json).as("struct"))
+   .select(

.select(
GetStructField('struct, 0).as("a"),
GetStructField('struct, 1).as("b"),
GetStructField('struct, 2).as("c")).analyze
Contributor: When using the Dataset API, it is very common to chain withColumn calls:

dataset
    .withColumn("json", ...)
    .withColumn("a", col("json").getField("a"))
    .withColumn("b", col("json").getField("b"))
    .withColumn("c", col("json").getField("c"))

In that case the query should look more like this:

        val query = relation
          .select('json, JsonToStructs(schema, options, 'json).as("struct"))
          .select('json, 'struct, GetStructField('struct, 0).as("a"))
          .select('json, 'struct, 'a, GetStructField('struct, 1).as("b"))
          .select('json, 'struct, 'a, 'b, GetStructField('struct, 2).as("c"))
          .analyze

The CollapseProject rule uses transformUp. It seems that in that case we do not get the expected results from this optimization.

Member Author: It seems this can be fixed by using transformDown instead? It seems to me that CollapseProject doesn't necessarily need to use transformUp, if I'm not missing anything. cc @cloud-fan @maropu

Contributor (@tanelk, Oct 9, 2020): If there is a chain of projects P1(P2(P3(P4(...)))), then transformDown will first merge P1 and P2 into P12, and then move on to its child P3 and merge it with P4 into P34. Only on the second iteration will it merge all four of them.

In this case we want to merge P123 first and then see that we can't merge with P4, because we would exceed maxCommonExprsInCollapseProject.

Contributor: I think the correct way would be to use transformDown in a similar manner to recursiveRemoveSort in #21072. So basically, when you hit the first Project, you collect all consecutive Projects until you hit the maxCommonExprsInCollapseProject limit, and merge them.

Member: Hm, that sounds fine too. Rather, it seems a top-down transformation can collapse projects in one shot, just like RemoveRedundantProjects?

Member Author: It seems we need to change to transformDown and take a recursive approach like RemoveRedundantProjects and recursiveRemoveSort for collapsing Projects.
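For illustration, a rough sketch of that direction (a hypothetical helper, following the recursiveRemoveSort pattern referenced above, and reusing the helpers from CollapseProject): collapse consecutive Projects top-down and stop once a merge would exceed the limit.

    // Hypothetical helper, not part of this PR: fold a chain of Projects greedily,
    // stopping before the merged Project would duplicate a common output more than
    // the allowed number of times.
    private def recursiveCollapseProject(plan: LogicalPlan): LogicalPlan = plan match {
      case p1 @ Project(_, p2: Project)
          if !haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) &&
            getLargestNumOfCommonOutput(p1.projectList, p2.projectList) <=
              SQLConf.get.maxCommonExprsInCollapseProject =>
        // Merge the two Projects first, then keep collapsing into the child.
        recursiveCollapseProject(
          p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)))
      case other =>
        other.mapChildren(recursiveCollapseProject)
    }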

val optimized = Optimize.execute(query)

if (maxCommonExprs.toInt < 3) {
val expected = query
comparePlans(optimized, expected)
} else {
val expected = relation.select(
GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
comparePlans(optimized, expected)
}
}
}
}
}
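For context, a user-facing shape of the pattern this test models (a sketch; the table and column names are made up, and schema stands for the StructType used in the test above): each getField call references the struct produced by from_json, so fully collapsing the chained Projects would repeat the expensive from_json three times.

    import org.apache.spark.sql.functions.{col, from_json}

    // Hypothetical input table "events" with a string column "json".
    val parsed = spark.table("events")
      .withColumn("struct", from_json(col("json"), schema))
      .withColumn("a", col("struct").getField("a"))
      .withColumn("b", col("struct").getField("b"))
      .withColumn("c", col("struct").getField("c"))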