
[SPARK-31973][SQL] Skip partial aggregates if grouping keys have high cardinality #28804

Status: Closed. Wants to merge 33 commits.

Changes shown are from 21 of the 33 commits.

Commits:
- db8a62d Fix: Init commit (karuppayya, Jun 12, 2020)
- 9a59925 Fix: Fix UT name (karuppayya, Jun 12, 2020)
- feacdcf Fix: Fix codegen (karuppayya, Jun 12, 2020)
- ab98ea4 Revert "Fix: Fix codegen" (karuppayya, Jun 12, 2020)
- 2e102d1 Fix: Fix codegen logic (karuppayya, Jun 12, 2020)
- 5fa601b Fix: Fix codegen logic (karuppayya, Jun 12, 2020)
- 220eaed Fix: Fix codegen logic (karuppayya, Jun 17, 2020)
- 452b632 Fix: clean up (karuppayya, Jun 17, 2020)
- 68dd5a3 Fix: remove partialmerge (karuppayya, Jun 17, 2020)
- 692fd1b Fix: fix typo, remove whitelines (karuppayya, Jun 17, 2020)
- f1b6ac1 Fix: Fix UT attempt (karuppayya, Jun 18, 2020)
- 05c891f Fix: Address review comments (karuppayya, Jun 18, 2020)
- dd3c56a Fix: UT fixes, refactoring (karuppayya, Jun 19, 2020)
- cb8b922 Fix: fix indent (karuppayya, Jun 19, 2020)
- 7952aa7 UT: Add more test (karuppayya, Jun 19, 2020)
- 56c95e2 Fix UT attempt (karuppayya, Jun 19, 2020)
- 43237ba Enabling the conf to runn all tests with the feature (karuppayya, Jun 20, 2020)
- 99c1d22 Unit test fix attempt (karuppayya, Jun 24, 2020)
- d2873a3 UT fix attmpt (karuppayya, Jun 24, 2020)
- afc2903 Ut fix attempt (karuppayya, Jun 25, 2020)
- 7766401 Add heuristic (karuppayya, Jul 3, 2020)
- 75125d9 Fix: Include missing change, remove unnecessary changes, handle comments (karuppayya, Jul 6, 2020)
- 3ca81ae Refactor: avoid additional code on reducer, fix tests (karuppayya, Jul 8, 2020)
- 8850777 gst (karuppayya, Jul 8, 2020)
- 26a2fd6 Address review comments (karuppayya, Jul 15, 2020)
- c088816 Address review commenst (karuppayya, Aug 6, 2020)
- c49f106 Fix forward reference (karuppayya, Aug 6, 2020)
- 69f1d71 UT fixes, address review comments (karuppayya, Aug 7, 2020)
- c9a415d Address review copmments (karuppayya, Aug 7, 2020)
- ceaa4e5 Fix style check (karuppayya, Aug 7, 2020)
- 2ae5525 Fix UT (karuppayya, Aug 7, 2020)
- 0a186f0 UT fix (karuppayya, Aug 7, 2020)
- 11572a1 Address review comments (karuppayya, Aug 11, 2020)
@@ -2196,6 +2196,25 @@ object SQLConf {
.checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
.createWithDefault(16)

val SKIP_PARTIAL_AGGREGATE_ENABLED =
buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
Contributor:

so this only works for hash aggregate but not the sort based aggregate?

Contributor (author):

I believe this heuristic can be applied to sort-based aggregation as well. I started with hash-based aggregation; I will create a new PR for sort-based aggregation.

.internal()
.doc("Avoid sort/spill to disk during partial aggregation")
.booleanConf
.createWithDefault(true)
Member:

Could we use a threshold + column stats instead of this boolean config?

Contributor (author):

I didn't get the threshold part. Can you please elaborate?

Member:

I meant the ratio of the distinct row count to the total row count in the group-by key's column stats. For example, if distinctCount / rowCount is close to 1.0, you apply the optimization; otherwise, you don't.
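The suggested stats-based check could look roughly like this. This is a minimal standalone sketch with assumed names (KeyStats and the 0.9 cutoff are illustrative stand-ins, not Spark's catalyst statistics classes):

```scala
// Stand-in for per-column statistics; names are assumed for illustration.
final case class KeyStats(distinctCount: Long, rowCount: Long)

// Apply the skip-partial-aggregation optimization only when the group-by
// keys are nearly unique, i.e. distinctCount / rowCount is close to 1.0.
def shouldSkipPartialAgg(stats: KeyStats, cutoff: Double = 0.9): Boolean =
  stats.rowCount > 0 && stats.distinctCount.toDouble / stats.rowCount >= cutoff
```

As noted below in the thread, this relies on column stats having been computed, which is often not the case in practice.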

Contributor (author):

Thanks @maropu for explaining; I will make this change.

Contributor (author):

@maropu This is a very useful suggestion. One issue is that column stats are rarely computed. We came across related work in Hive (https://issues.apache.org/jira/browse/HIVE-291): they turn off map-side aggregation (i.e., partial aggregation becomes a pass-through) in the Group-By physical operator if map-side aggregation does not reduce the entries by at least half, and they look at 100000 rows to decide that (ref: patch https://issues.apache.org/jira/secure/attachment/12400257/291.1.txt). Should we do something similar in HashAggregateExec here? Any thoughts on this?

Member (@maropu, Jun 30, 2020):

> They turn off map-side aggregation (i.e., partial aggregation becomes a pass-through) in the Group-By physical operator if map-side aggregation does not reduce the entries by at least half, and they look at 100000 rows to do that

I think whether that approach improves performance depends on IO performance, but the idea looks interesting to me. WDYT? @cloud-fan


val SKIP_PARTIAL_AGGREGATE_THRESHOLD =
buildConf("spark.sql.aggregate.partialaggregate.skip.threshold")
.internal()
.longConf
.createWithDefault(100000)
Contributor:

I think it must be data-dependent. If I have 10^10 records, and partial agg cuts it down to 10^8 (1% of the original input), it's still worth doing partial agg.

Contributor (author):

@cloud-fan We skip partial aggregation only when the aggregation was not able to cut down the records by 50% (defined by spark.sql.aggregate.partialaggregate.skip.ratio). In that case it would not kick in.
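The runtime decision described here combines the two configs: partial aggregation is skipped only after observing at least `skip.threshold` rows, and only if the hash map has failed to shrink the input below `skip.ratio`. A minimal standalone sketch of that decision (config names are from this PR; the condensed logic is a sketch, not the generated code itself):

```scala
// rowsSeen: input rows processed so far; mapEntries: distinct keys in the hash map.
// Skip only past the threshold, and only when aggregation kept more than
// `ratio` of the rows (i.e. cut the input down by less than 50%).
def skipPartialAgg(
    rowsSeen: Long,
    mapEntries: Long,
    threshold: Long = 100000L, // spark.sql.aggregate.partialaggregate.skip.threshold
    ratio: Double = 0.5        // spark.sql.aggregate.partialaggregate.skip.ratio
): Boolean =
  rowsSeen >= threshold && mapEntries.toDouble / rowsSeen > ratio
```

In the 10^10 → 10^8 example above, mapEntries / rowsSeen is 0.01, so the check stays false and partial aggregation continues.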


val SKIP_PARTIAL_AGGREGATE_RATIO =
buildConf("spark.sql.aggregate.partialaggregate.skip.ratio")
Member:

Could you describe the two params in more detail in .doc? Reading them, I couldn't understand how to use them.

Contributor (author):

Updated. Can you please check whether it is explanatory?

.internal()
.doubleConf
.createWithDefault(0.5)
Member:

Why do we need two params for this optimization?

Member:

Also, could you check performance numbers by varying the params?

Contributor (author):

@maropu I have borrowed this heuristic from Hive. We can merge them into one. Any suggestions here?


val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
.doc("Compression codec used in writing of AVRO files. Supported codecs: " +
"uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")
@@ -2922,6 +2941,12 @@ class SQLConf extends Serializable with Logging {

def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)

def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_THRESHOLD)

def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_RATIO)

def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)
@@ -353,4 +353,8 @@ object AggUtils {

finalAndCompleteAggregate :: Nil
}

def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = {
modes.nonEmpty && modes.forall(_ == Partial)
Member:

Can't we apply this optimization in the empty case, too? e.g.,

scala> sql("select k from t group by k").explain()
== Physical Plan ==
*(2) HashAggregate(keys=[k#29], functions=[])
+- Exchange hashpartitioning(k#29, 200), true, [id=#47]
   +- *(1) HashAggregate(keys=[k#29], functions=[])
      +- *(1) ColumnarToRow
         +- FileScan parquet default.t[k#29] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/t], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<k:int>

Contributor (author, @karuppayya, Aug 7, 2020):

In this case, the reducer side also does not have any aggregate function, and we might end up not aggregating the data.

}
}
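The helper can be exercised in isolation. This is a self-contained rendering with AggregateMode stubbed out for illustration (in Spark the modes live in catalyst); it shows that an empty modes list, as in the group-by-without-aggregates plan discussed above, returns false:

```scala
// Stubbed aggregate modes; stand-ins for catalyst's AggregateMode hierarchy.
sealed trait AggregateMode
case object Partial extends AggregateMode
case object Final extends AggregateMode

// Same logic as AggUtils.areAggExpressionsPartial in this PR: true only when
// there is at least one aggregate expression and every one is in Partial mode.
def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean =
  modes.nonEmpty && modes.forall(_ == Partial)
```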
@@ -63,6 +63,8 @@ case class HashAggregateExec(

require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

override def needStopCheck: Boolean = skipPartialAggregate
Contributor:

why do we need to add stopCheck to child operators after this change?


override lazy val allAttributes: AttributeSeq =
child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
@@ -72,6 +74,8 @@ case class HashAggregateExec(
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
"partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
"number of skipped records for partial aggregates"),
Member:

I think this metric is only meaningful for aggregates with a partial mode, so could we show it only in that mode?

Member:

partialAggSkipped -> numAggSkippedRows?

"avgHashProbe" ->
SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))

@@ -409,6 +413,12 @@ case class HashAggregateExec(
private var fastHashMapTerm: String = _
private var isFastHashMapEnabled: Boolean = false

private var avoidSpillInPartialAggregateTerm: String = _
private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty
Member:

Why do we need this check find(_.isInstanceOf[ExpandExec]).isEmpty?

Contributor (author):

This is required to avoid applying this optimization to queries with more than one distinct aggregate. org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates takes care of rewriting aggregates with more than one distinct, and that rule assumes map-side aggregation has performed the distinct operation. With my change this would produce wrong results. For example, take the first example given in the comments of the rule:

 * First example: query without filter clauses (in scala):
 * {{{
 *   val data = Seq(
 *     ("a", "ca1", "cb1", 10),
 *     ("a", "ca1", "cb2", 5),
 *     ("b", "ca1", "cb1", 13))
 *     .toDF("key", "cat1", "cat2", "value")
 *   data.createOrReplaceTempView("data")
 *
 *   val agg = data.groupBy($"key")
 *     .agg(
 *       countDistinct($"cat1").as("cat1_cnt"),
 *       countDistinct($"cat2").as("cat2_cnt"),
 *       sum($"value").as("total"))
 * }}}
 *
 * This translates to the following (pseudo) logical plan:
 * {{{
 * Aggregate(
 *    key = ['key]
 *    functions = [COUNT(DISTINCT 'cat1),
 *                 COUNT(DISTINCT 'cat2),
 *                 sum('value)]
 *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
 *   LocalTableScan [...]
 * }}}
 *
 * This rule rewrites this logical plan to the following (pseudo) logical plan:
 * {{{
 * Aggregate(
 *    key = ['key]
 *    functions = [count(if (('gid = 1)) 'cat1 else null),
 *                 count(if (('gid = 2)) 'cat2 else null),
 *                 first(if (('gid = 0)) 'total else null) ignore nulls]
 *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
 *   Aggregate(
 *      key = ['key, 'cat1, 'cat2, 'gid]
 *      functions = [sum('value)]
 *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
 *     Expand(
 *        projections = [('key, null, null, 0, cast('value as bigint)),
 *                       ('key, 'cat1, null, 1, null),
 *                       ('key, null, 'cat2, 2, null)]
 *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
 *       LocalTableScan [...]
 * }}}
 *

Say the following two records are in the dataset:

rec 1: ("key1", "cat1", "cat2", 1)
rec 2: ("key1", "cat1", "cat2", 1)

With my change, after Expand:

("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat2", 2, null)
("key1", null, null, 0, 1)
("key1", "cat1", null, 1, null)
("key1", null, "cat2", 2, null)

Since partial aggregation is skipped, these duplicate rows are not collapsed before the shuffle, so the reducer-side aggregation result is: (key1, 2, 2, 2)
But the correct answer is: (key1, 1, 1, 2)

Hence the check for the presence of an Expand node, to avoid skipping partial aggregation in this case.
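The wrong-result scenario can be reproduced with a toy model in plain Scala collections (not Spark code): in the rewritten plan, the outer aggregate computes COUNT(DISTINCT cat1) by counting gid == 1 rows, which is only correct if the duplicate expanded rows were collapsed by the partial aggregate first.

```scala
// Expanded rows as (key, cat1, cat2, gid); the two identical input records
// produce two identical groups of three projections each.
val expanded: Seq[(String, String, String, Int)] = Seq(
  ("key1", null, null, 0),
  ("key1", "cat1", null, 1),
  ("key1", null, "cat2", 2),
  ("key1", null, null, 0),
  ("key1", "cat1", null, 1),
  ("key1", null, "cat2", 2))

// COUNT(DISTINCT cat1) in the rewritten plan = number of rows with gid == 1.
def cat1Count(rows: Seq[(String, String, String, Int)]): Int =
  rows.count(_._4 == 1)

// Partial aggregation would collapse duplicate (key, cat1, cat2, gid) rows
// map-side; skipping it lets both copies of the gid == 1 row reach the reducer.
val collapsed = expanded.distinct
```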

Member (@maropu, Aug 7, 2020):

Hm, I see... But I think the current approach looks dangerous, because we might add a new plan having the same assumption in the future.

Contributor (author):

True. Any suggestions on handling this in a better way?

Member:

I have no smart idea for that right now. Doesn't Hive have that kind of plan?

private var rowCountTerm: String = _
private var outputFunc: String = _

// whether a vectorized hashmap is used instead
// we have decided to always use the row-based hashmap,
// but the vectorized hashmap can still be switched on for testing and benchmarking purposes.
@@ -628,6 +638,8 @@ case class HashAggregateExec(
|${consume(ctx, resultVars)}
""".stripMargin
}


ctx.addNewFunction(funcName,
s"""
|private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
@@ -680,6 +692,10 @@ case class HashAggregateExec(

private def doProduceWithKeys(ctx: CodegenContext): String = {
Member:

Why did you apply this optimization only for the with-key case?

Contributor (author):

There is only one key for the map in the without-key case, so the optimization does not apply.

Member:

Ah, I see.

val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
avoidSpillInPartialAggregateTerm = ctx.
addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate")
val childrenConsumed = ctx.
addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
if (sqlContext.conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
} else if (sqlContext.conf.enableVectorizedHashMap) {
@@ -750,18 +766,19 @@ case class HashAggregateExec(
finishRegularHashMap
}

outputFunc = generateResultFunction(ctx)
val doAggFuncName = ctx.addNewFunction(doAgg,
s"""
|private void $doAgg() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| $childrenConsumed = true;
Member:

Why do we need this variable?

Contributor (author):

Before my change, all the records of the dataset were processed and the HashMap populated (+ spilled to the sorter) in one pass. With my change, the records are not all processed at once, due to the change at line #878; hence we need to know whether all the records of the dataset have been processed, and if not, process them. This variable serves that purpose.
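The consume-then-resume control flow being described can be modeled outside codegen. This is a toy sketch, not the generated Java: doAgg() may return before the input iterator is drained (once the operator starts emitting rows downstream), so the caller re-enters it until childrenConsumed flips.

```scala
// Toy operator: processes input in bounded chunks, mimicking a doAgg() that
// can stop early once rows start being emitted downstream.
class ToyAgg(input: Iterator[Int], chunkSize: Int) {
  var childrenConsumed = false
  var processed = 0

  def doAgg(): Unit = {
    var n = 0
    while (input.hasNext && n < chunkSize) {
      input.next(); processed += 1; n += 1
    }
    // Only when the iterator is drained is the input marked fully consumed.
    if (!input.hasNext) childrenConsumed = true
  }

  def run(): Int = {
    doAgg()                            // first call may stop early
    while (!childrenConsumed) doAgg()  // re-enter until all input is processed
    processed
  }
}
```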

| $finishHashMap
|}
""".stripMargin)

// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputFunc = generateResultFunction(ctx)
Member:

Why did you move this line to line 771?

Contributor (author):

doAggFuncName calls the child's produce and eventually ends up in doConsumeWithKeys, where I use outputFunc at line 1200; without the change it would not be initialized yet and would be null.


def outputFromFastHashMap: String = {
if (isFastHashMapEnabled) {
@@ -833,11 +850,18 @@ case class HashAggregateExec(
s"""
|if (!$initAgg) {
| $initAgg = true;
| $avoidSpillInPartialAggregateTerm =
| ${Utils.isTesting} && $skipPartialAggregate;
| $createFastHashMap
| $hashMapTerm = $thisPlan.createHashMap();
| long $beforeAgg = System.nanoTime();
| $doAggFuncName();
| $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
| $shouldStopCheckCode;
|}
|if (!$childrenConsumed) {
| $doAggFuncName();
| $shouldStopCheckCode;
|}
|// output the result
|$outputFromFastHashMap
@@ -877,44 +901,61 @@ case class HashAggregateExec(
("true", "true", "", "")
}

val oomeClassName = classOf[SparkOutOfMemoryError].getName
val skipPartialAggregateThreshold = sqlContext.conf.skipPartialAggregateThreshold
val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio

val oomeClassName = classOf[SparkOutOfMemoryError].getName
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val findOrInsertRegularHashMap: String =
s"""
|// generate grouping key
|${unsafeRowKeyCode.code}
|int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
|if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =
| $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
|}
|// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
|// aggregation after processing all input rows.
|if ($unsafeRowBuffer == null) {
| if ($sorterTerm == null) {
| $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
| } else {
| $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
|if (!$avoidSpillInPartialAggregateTerm) {
| // generate grouping key
| ${unsafeRowKeyCode.code}
| int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
| if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =
| $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
| }
| $resetCounter
| // the hash map had be spilled, it should have enough memory now,
| // try to allocate buffer again.
| $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
| $unsafeRowKeys, $unsafeRowKeyHash);
| if ($unsafeRowBuffer == null) {
| // failed to allocate the first page
| throw new $oomeClassName("No enough memory for aggregation");
| // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
| // aggregation after processing all input rows.
| if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
Contributor:

Is this check redundant? We are already checking for !avoidSpillInPartialAggregateTerm above.

Contributor (author):

@prakharjain09
The first check kicks in once we have decided to avoid partial aggregation, in which case we won't attempt to fetch from the map.
The second check kicks in only when the map is completely exhausted, which is when we decide whether to skip partial aggregation.
I think both checks are required.

| // If sort/spill to disk is disabled, nothing is done.
| // Aggregation buffer is created later
| $countTerm = $countTerm + $hashMapTerm.getNumRows();
| boolean skipPartAgg =
| !($rowCountTerm < $skipPartialAggregateThreshold) &&
| ($countTerm/$rowCountTerm) > $skipPartialAggRatio;
| if ($skipPartialAggregate && skipPartAgg) {
| $avoidSpillInPartialAggregateTerm = true;
| } else {
| if ($sorterTerm == null) {
| $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
| } else {
| $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
| }
| $resetCounter
| // the hash map had be spilled, it should have enough memory now,
| // try to allocate buffer again.
| $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
| $unsafeRowKeys, $unsafeRowKeyHash);
| if ($unsafeRowBuffer == null) {
| // failed to allocate the first page
| throw new $oomeClassName("No enough memory for aggregation");
| }
| }
| }
|}
""".stripMargin

val partTerm = metricTerm(ctx, "partialAggSkipped")

val findOrInsertHashMap: String = {
if (isFastHashMapEnabled) {
val insertCode = if (isFastHashMapEnabled) {
// If fast hash map is on, we first generate code to probe and update the fast hash map.
// If the probe is successful the corresponding fast row buffer will hold the mutable row.
s"""
|if ($checkFallbackForGeneratedHashMap) {
|if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) {
| ${fastRowKeys.map(_.code).mkString("\n")}
| if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
| $fastRowBuffer = $fastHashMapTerm.findOrInsert(
@@ -923,12 +964,25 @@ case class HashAggregateExec(
|}
|// Cannot find the key in fast hash map, try regular hash map.
|if ($fastRowBuffer == null) {
| $countTerm = $countTerm + $fastHashMapTerm.getNumRows();
| $findOrInsertRegularHashMap
|}
""".stripMargin
} else {
findOrInsertRegularHashMap
}
val initExpr = declFunctions.flatMap(f => f.initialValues)
val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr)
s"""
|$insertCode
|// Create an empty aggregation buffer
|if ($avoidSpillInPartialAggregateTerm) {
| ${unsafeRowKeyCode.code}
| ${emptyBufferKeyCode.code}
| $unsafeRowBuffer = ${emptyBufferKeyCode.value};
| $partTerm.add(1);
|}
|""".stripMargin
}

val inputAttr = aggregateBufferAttributes ++ inputAttributes
@@ -1005,7 +1059,7 @@ case class HashAggregateExec(
}

val updateRowInHashMap: String = {
if (isFastHashMapEnabled) {
val updateRowinMap = if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
@@ -1080,6 +1134,12 @@ case class HashAggregateExec(
} else {
updateRowInRegularHashMap
}
s"""
|$updateRowinMap
|if ($avoidSpillInPartialAggregateTerm) {
| $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer);
|}
|""".stripMargin
}

val declareRowBuffer: String = if (isFastHashMapEnabled) {
@@ -136,6 +136,14 @@ abstract class HashMapGenerator(
""".stripMargin
}

protected final def generateNumRows(): String = {
s"""
|public int getNumRows() {
| return batch.numRows();
|}
""".stripMargin
}

protected final def genComputeHash(
ctx: CodegenContext,
input: String,
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -51,6 +51,51 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(df.collect() === Array(Row(9, 4.5)))
}

test(s"Avoid spill in partial aggregation" ) {
withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
// Create Dataframes
val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
val partAggNode = aggDF.queryExecution.executedPlan.find {
case h: HashAggregateExec =>
AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode))
case _ => false
}
checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
assert(partAggNode.isDefined,
"No HashAggregate node with partial aggregate expression found")
assert(partAggNode.get.metrics("partialAggSkipped").value == data.size,
"Partial aggregation got triggered in partial hash aggregate node")
}
}

test(s"Partial aggregation should not happen when no Aggregate expr" ) {
withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
val aggDF = testData2.select(sumDistinct($"a"))
val aggNodes = aggDF.queryExecution.executedPlan.collect {
case h: HashAggregateExec => h
}
checkAnswer(aggDF, Row(6))
assert(aggNodes.nonEmpty)
assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0))
}
}

test(s"Distinct: Partial aggregation should happen for" +
s" HashAggregate nodes performing partial Aggregate operations " ) {
withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
val aggNodes = aggDF.queryExecution.executedPlan.collect {
case h: HashAggregateExec => h
}
val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec])
checkAnswer(aggDF, Row(6, 9))
assert(baseNodes.size == 1 )
assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
assert(other.forall(_.metrics("partialAggSkipped").value == 0))
}
}

test("Aggregate with grouping keys should be included in WholeStageCodegen") {
val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)
val plan = df.queryExecution.executedPlan