
[SPARK-31973][SQL] Skip partial aggregates if grouping keys have high cardinality #28804

Closed · wants to merge 33 commits

Conversation

@karuppayya (Contributor) commented Jun 12, 2020

What changes were proposed in this pull request?

In the case of hash aggregation, a partial aggregation (update) is done first, followed by a final aggregation (merge).

During partial aggregation, we sort and spill to disk every time the fast map (when enabled) and the UnsafeFixedWidthAggregationMap get exhausted.

When the cardinality of the grouping columns is close to the total number of records being processed, this sorting and spilling is effectively a no-op: the rows could be used directly in the final aggregation instead.

Even a user who knows the nature of the data currently has no way to disable this sort-and-spill step.
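
For illustration (a hypothetical example, not from this PR): a group-by on a nearly unique key reduces almost nothing in the partial phase, so its sort/spill work is wasted:

// Hypothetical illustration: the grouping key is nearly unique, so the
// map-side (partial) aggregation barely reduces the row count and its
// sort/spill work is wasted.
val highCard = spark.range(10000000L)
  .selectExpr("uuid() AS id", "1L AS value")
highCard.groupBy("id").sum("value").count()   // partial agg is a near no-op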

This is similar to following issues in Hive:
https://issues.apache.org/jira/browse/HIVE-223
https://issues.apache.org/jira/browse/HIVE-291

This PR adds the ability to disable the sort/spill during partial aggregation.
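
A minimal usage sketch, assuming the config name used in the benchmark code below (the name changed during review):

// Sketch: toggle the new behavior via SQL conf (config name as used in the
// benchmark in this description; it was renamed during review).
spark.sql("SET spark.sql.aggregate.partialaggregate.skip.enabled=true")
val agg = spark.sql("SELECT name, SUM(value3) s FROM tbl GROUP BY name")
agg.collect()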

Benchmark

spark.executor.memory = 12G

Init code

// init code
case class Data(name: String, value1: String, value2: String, value3: Long, random: Int)
val numRecords = Seq(60000000)
val tblName = "tbl"

Generate data

val numRecords = Seq(30000000, 60000000)

val basePath = "s3://qubole-spar/karuppayya/SPAR-4477/benchmark/"
val rand = scala.util.Random
// write
numRecords.foreach {
  recordCount =>
    val dataLocation = s"$basePath/$recordCount"
    val dataDF = spark.range(recordCount).map {
      x =>
        if (x < 10) Data("name1", "value1", "value1", 10, rand.nextInt(100))
        else Data(s"name$x", s"value$x", s"value$x", 1, rand.nextInt(100))
    }
    // creating data to be processed by one task (also gzip-ing to ensure
    // Spark doesn't create multiple splits)
    val randomDF = dataDF.orderBy("random")
    randomDF.drop("random").repartition(1)
      .write
      .mode("overwrite")
      .option("compression", "gzip")
      .parquet(dataLocation)
}

Query

val query =
  s"""
    |SELECT name, value1, value2, SUM(value3) s
    |FROM $tblName
    |GROUP BY name, value1, value2
    |""".stripMargin

Benchmark code

import org.apache.spark.sql.types._

// user-specified schema for reading the benchmark data
val userSpecifiedSchema = new StructType()
  .add(StructField("name", StringType))
  .add(StructField("value1", StringType))
  .add(StructField("value2", StringType))
  .add(StructField("value3", LongType))
val query =
  """
    |SELECT name, value1, value2, SUM(value3) s
    |FROM tbl
    |GROUP BY name, value1, value2
    |""".stripMargin

case class Metric(recordCount: Long, partialAggregateEnabled: Boolean, timeTaken: Long)
val metrics = Seq(true, false).flatMap {
  enabled =>
    sql(s"set spark.sql.aggregate.partialaggregate.skip.enabled=$enabled").collect
    numRecords.map {
      recordCount =>
        import java.util.concurrent.TimeUnit.NANOSECONDS
        val dataLocation = s"$basePath/$recordCount"
        spark.read
          .option("inferTimestamp", "false")
          .schema(userSpecifiedSchema)
          .json(dataLocation)
          .createOrReplaceTempView("tbl")
        val start = System.nanoTime()
        spark.sql(query).filter("s > 10").collect
        val end = System.nanoTime()
        val diff = end - start
        Metric(recordCount, enabled, NANOSECONDS.toMillis(diff))
    }
}

Results

val df = metrics.toDF
df.createOrReplaceTempView("a")
val result = sql("SELECT * FROM a ORDER BY recordCount DESC, partialAggregateEnabled")
result.show()
+-----------+-----------------------+---------+
|recordCount|partialAggregateEnabled|timeTaken|
+-----------+-----------------------+---------+
|   90000000|                  false|   593844|
|   90000000|                   true|   412958|
|   60000000|                  false|   377054|
|   60000000|                   true|   276363|
+-----------+-----------------------+---------+

(timeTaken is in milliseconds)

Percent improvement:

90000000 → 30.46%, 60000000 → 26.70%

Why are the changes needed?

Skipping the unnecessary sort/spill improves the performance of aggregation queries whose grouping keys have high cardinality.

Does this PR introduce any user-facing change?

No

How was this patch tested?

This patch was tested manually

@karuppayya karuppayya changed the title SPARK-31973: Add ability to disable Sort,Spill in Partial aggregation [SPARK-31973][SQL] Add ability to disable Sort,Spill in Partial aggregation Jun 12, 2020
@@ -165,6 +166,26 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}

test("SPARK-: Avoid spill in partial aggregation " +
"when spark.sql.aggregate.spill.partialaggregate.disabled is set") {
withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true"),
@karuppayya (Contributor Author):

This is not sufficient, since the rows still get added to UnsafeFixedWidthAggregationMap; I need to figure out a way to avoid adding elements to UnsafeFixedWidthAggregationMap. Please advise @cloud-fan @gatorsmile @maropu

@maropu (Member):

Could you show us performance numbers?

@karuppayya (Contributor Author):

@maropu I figured out a few more improvements that can be made to the generated code; I will test them and also add the benchmark numbers.
Adding a WIP tag to the title.

@karuppayya (Contributor Author):

@maropu I have added the benchmark to the description.
I have also figured out a way to add a UT for this.
@maropu @cloud-fan @gatorsmile Can you please help review the change?

@karuppayya karuppayya changed the title [SPARK-31973][SQL] Add ability to disable Sort,Spill in Partial aggregation [SPARK-31973][SQL][WIP] Add ability to disable Sort,Spill in Partial aggregation Jun 12, 2020
@karuppayya karuppayya changed the title [SPARK-31973][SQL][WIP] Add ability to disable Sort,Spill in Partial aggregation [SPARK-31973][SQL] Add ability to disable Sort,Spill in Partial aggregation Jun 12, 2020
@karuppayya karuppayya changed the title [SPARK-31973][SQL] Add ability to disable Sort,Spill in Partial aggregation [SPARK-31973][SQL][WIP] Add ability to disable Sort,Spill in Partial aggregation Jun 13, 2020
@karuppayya karuppayya requested a review from maropu June 17, 2020 19:58
@karuppayya karuppayya changed the title [SPARK-31973][SQL][WIP] Add ability to disable Sort,Spill in Partial aggregation [SPARK-31973][SQL] Add ability to disable Sort,Spill in Partial aggregation Jun 17, 2020
@maropu (Member) commented Jun 18, 2020

ok to test

@@ -2173,6 +2173,13 @@ object SQLConf {
.checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
.createWithDefault(16)

val SPILL_PARTIAL_AGGREGATE_DISABLED =
buildConf("spark.sql.aggregate.spill.partialaggregate.disabled")
Member:

disabled -> enabled to follow the other config naming.

@karuppayya (Contributor Author):

I have renamed the config. Can you please check?

@SparkQA commented Jun 18, 2020

Test build #124189 has finished for PR 28804 at commit 5f05aa7.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@karuppayya (Contributor Author):

@maropu @cloud-fan Can you please add me to the whitelist to trigger the tests?

@SparkQA commented Jun 18, 2020

Test build #124199 has finished for PR 28804 at commit 2b3704b.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@maropu (Member) commented Jun 18, 2020

add to whitelist

@SparkQA commented Jun 18, 2020

Test build #124209 has finished for PR 28804 at commit 2b3704b.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jun 19, 2020

Test build #124294 has finished for PR 28804 at commit 9a10261.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jun 19, 2020

Test build #124295 has finished for PR 28804 at commit 1c0399a.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jun 19, 2020

Test build #124296 has finished for PR 28804 at commit 73de4c8.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 7, 2020

Test build #127214 has finished for PR 28804 at commit ceaa4e5.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 7, 2020

Test build #127216 has finished for PR 28804 at commit 2ae5525.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Aug 8, 2020

Test build #127218 has finished for PR 28804 at commit 0a186f0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

|// output the result
|$outputFromFastHashMap
|$outputFromRegularHashMap
""".stripMargin
}

override def needStopCheck: Boolean = skipPartialAggregateEnabled
@maropu (Member) commented Aug 8, 2020:

If skipPartialAggregateEnabled = true but #rows/cardinality doesn't go over the threshold, partial aggregates are not skipped. Even in that case, do we set needStopCheck to true?

@karuppayya (Contributor Author):

When the ratio of cardinality to numRows:

  • does not go beyond the threshold, it implies the optimization has not kicked in yet. In that case org.apache.spark.sql.execution.CodegenSupport#shouldStopCheckCode returns false, and we continue iterating over the remaining items of the iterator.
  • goes beyond the threshold, we add the row (since the addition to the map is skipped) to org.apache.spark.sql.execution.BufferedRowIterator#currentRows, which gets consumed by the parent.

Since this is an inexpensive operation that is already used in many places in HashAggregateExec, and I didn't see any performance penalties, this approach seemed OK to me.

Please let me know if you have any other suggestions.
Generated code:

private void agg_doAggregateWithKeys_0() throws java.io.IOException {
/* 318 */     while ( localtablescan_input_0.hasNext()) {
/* 319 */       InternalRow localtablescan_row_0 = (InternalRow) localtablescan_input_0.next();
/* 320 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1);
/* 321 */       boolean localtablescan_isNull_0 = localtablescan_row_0.isNullAt(0);
/* 322 */       UTF8String localtablescan_value_0 = localtablescan_isNull_0 ?
/* 323 */       null : (localtablescan_row_0.getUTF8String(0));
/* 324 */       int localtablescan_value_1 = localtablescan_row_0.getInt(1);
/* 325 */
/* 326 */       agg_doConsume_0(localtablescan_value_0, localtablescan_isNull_0, localtablescan_value_1);
/* 327 */       if (shouldStop()) return; // code added as part of needStopCheck
/* 328 */     }
/* 329 */
/* 330 */     agg_childrenConsumed_0 = true;
/* 331 */
/* 332 */     agg_fastHashMapIter_0 = agg_fastHashMap_0.rowIterator();
/* 333 */     agg_mapIter_0 = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap_0, agg_sorter_0, ((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* avgHashProbe */));
/* 334 */
/* 335 */   }

@cloud-fan (Contributor):

Can we have a short description of the approach in the PR description? Seems like we are making hash aggregate adaptive to do pass-through if it can't reduce data size in the first n rows.

@SparkQA commented Aug 11, 2020

Test build #127340 has finished for PR 28804 at commit 11572a1.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@maropu (Member) commented Aug 13, 2020

btw, could this optimization be implemented on the adaptive execution framework (AdaptiveSparkPlanExec)? In the initial discussion (https://github.com/apache/spark/pull/28804/files#r447158417), it was pointed out that accurate statistics could not be collected. But I think we might be able to collect the stats based on that framework. Yeah, as we know, we need to look for a light-weight way to compute cardinality on shuffle output. If we find one, I think we can simply drop the partial aggregate for high-cardinality cases.

Have you already considered this approach? What I'm worried about now is that the current implementation makes the code complicated, and it is limited to hash aggregates w/ codegen only.

@cloud-fan (Contributor):

I don't think AQE can help here. This is partial aggregate and usually there won't be a shuffle right before the partial agg.

@maropu (Member) commented Aug 13, 2020

I don't think AQE can help here. This is partial aggregate and usually there won't be a shuffle right before the partial agg.

Hm, I see. Even so, can't BasicStatsPlanVisitor propagate somewhat accurate input stats (of shuffle output) into partial aggregates?

@karuppayya (Contributor Author):

@maropu The stats (specifically, the number of records in the aggregation map after a threshold) that we are looking for are available only at the operator level, at runtime.

@maropu (Member) commented Aug 14, 2020

@maropu The stats (specifically, the number of records in the aggregation map after a threshold) that we are looking for are available only at the operator level, at runtime.

I was pointing not at the current approach, but at the previous one: https://github.com/apache/spark/pull/28804/files#r446720097

@cloud-fan (Contributor):

AQE doesn't provide column stats, and column stats propagation can be incorrect if we have many operators.

IIUC the current approach is: sample the first 100000 rows; if they can't reduce the data by half (which means each key has 2 values on average), then we skip the partial aggregate.

This sounds reasonable, but it's hard to tell how to pick the config values. @karuppayya do you have any experience of using it in practice?

@karuppayya (Contributor Author):

@cloud-fan
We observed this behaviour (partial aggregation not helping) at one of our customers.
Initially, I disabled partial aggregation completely by setting the aggregate mode to org.apache.spark.sql.catalyst.expressions.aggregate.Complete.
But later I found Hive's optimization for handling this scenario.
I have used Hive's heuristic (with a default of 100000 for the minimum rows to be sampled) in this PR; a sketch of it follows.
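
A minimal sketch of that heuristic, with hypothetical names, assuming the defaults discussed above (sample the first 100000 rows; skip the partial aggregate only if the hash map fails to reduce the data by at least half):

// Hypothetical sketch (names are illustrative, not the PR's actual code):
// after `minRows` input rows have been consumed, compare the number of
// distinct keys in the partial-aggregation hash map against the rows seen;
// if the reduction is less than 2x, stop aggregating and pass rows through.
def shouldSkipPartialAggregate(
    rowsConsumed: Long,
    distinctKeysInMap: Long,
    minRows: Long = 100000L,
    maxKeyRatio: Double = 0.5): Boolean = {
  rowsConsumed >= minRows &&
    distinctKeysInMap.toDouble / rowsConsumed > maxKeyRatio
}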

@SparkQA commented Oct 19, 2020

Test build #129988 has finished for PR 28804 at commit 11572a1.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Dec 13, 2020

Test build #132721 has finished for PR 28804 at commit 11572a1.

  • This patch fails Spark unit tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@c21 (Contributor) commented Mar 18, 2021

Any update for the PR? Thanks.

@c21 (Contributor) commented Jun 3, 2021

@karuppayya - we made a similar change internally and have already rolled it out to production at Facebook. We made some changes on top of this (e.g. only skipping the partial aggregate when the map needs to spill) and fixed several bugs. Do you mind if we submit a separate PR (listing you as co-author) to help move the feature forward? Thanks.

@karuppayya (Contributor Author):

@c21 Please go ahead

@github-actions

We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!
