[SPARK-31973][SQL] Skip partial aggregates if grouping keys have high cardinality #28804
@@ -2196,6 +2196,25 @@ object SQLConf {
      .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
      .createWithDefault(16)

  val SKIP_PARTIAL_AGGREGATE_ENABLED =
    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
      .internal()
      .doc("Avoid sort/spill to disk during partial aggregation")
      .booleanConf
      .createWithDefault(true)
Review thread:
- Could we use a threshold + column stats instead of this boolean config?
- I didn't get the threshold part. Can you please elaborate?
- That meant a ratio of the distinct row count to the total row count from the group-by key column stats. For example, if the number of distinct values is close to the total row count, partial aggregation is unlikely to reduce the data.
- Thanks @maropu for explaining, I will make this change.
- @maropu This is a very useful suggestion. One issue is that column stats are rarely computed. We came across this work in Hive: https://issues.apache.org/jira/browse/HIVE-291. They turn off the map-side aggregate (i.e., the partial aggregate becomes a pass-through) in the physical Group-By operator if map-side aggregation does not reduce the entries by at least half, and they look at 100000 rows to decide that (ref: patch https://issues.apache.org/jira/secure/attachment/12400257/291.1.txt). Should we do something similar in HashAggregateExec here? Any thoughts on this?
- I think whether that approach improves performance depends on IO performance, but the idea looks interesting to me. WDYT? @cloud-fan
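A minimal sketch of the Hive-style heuristic described above, written with hypothetical counter names rather than anything from this PR: look at a fixed number of input rows first, then skip map-side aggregation only if it failed to shrink them enough.

// Sketch only; `rowsSeen` and `groupsInHashMap` are assumed counters, not PR code.
def shouldSkipPartialAggregation(
    rowsSeen: Long,
    groupsInHashMap: Long,
    threshold: Long = 100000L, // cf. spark.sql.aggregate.partialaggregate.skip.threshold
    ratio: Double = 0.5        // cf. spark.sql.aggregate.partialaggregate.skip.ratio
  ): Boolean = {
  // Only decide once enough rows have been observed, and skip when the hash map
  // holds almost as many groups as rows were consumed (i.e., poor reduction).
  rowsSeen >= threshold && groupsInHashMap.toDouble / rowsSeen > ratio
}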
  val SKIP_PARTIAL_AGGREGATE_THRESHOLD =
    buildConf("spark.sql.aggregate.partialaggregate.skip.threshold")
      .internal()
      .longConf
      .createWithDefault(100000)
Review thread:
- I think it must be data-dependent. If I have 10^10 records and partial aggregation cuts them down to 10^8 (1% of the original input), it is still worth doing the partial aggregation.
- @cloud-fan We skip partial aggregation only when the aggregation was not able to cut down the records by 50% (defined by spark.sql.aggregate.partialaggregate.skip.ratio). In this case it will not kick in.
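To make that example concrete (a quick check, not part of the PR): reducing 10^10 rows to 10^8 groups is a reduction ratio of 10^8 / 10^10 = 0.01, far below the default 0.5, so the skip would not trigger and the partial aggregate would be kept in that scenario.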
  val SKIP_PARTIAL_AGGREGATE_RATIO =
    buildConf("spark.sql.aggregate.partialaggregate.skip.ratio")
      .internal()
      .doubleConf
      .createWithDefault(0.5)

Review thread:
- Could you describe this in more detail?
- Updated. Can you please check whether it is explanatory?
Review thread:
- Why do we need two params for this optimization?
- Also, could you check the performance numbers by varying the params?
- @maropu I have borrowed this heuristic from Hive. We can merge them into one. Any suggestions here?
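For reference, a usage sketch of the three configurations added in this diff (illustrative only; the values shown are simply the defaults from the diff):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Internal flags introduced above; names and defaults are taken from this diff.
spark.conf.set("spark.sql.aggregate.partialaggregate.skip.enabled", "true")
spark.conf.set("spark.sql.aggregate.partialaggregate.skip.threshold", "100000")
spark.conf.set("spark.sql.aggregate.partialaggregate.skip.ratio", "0.5")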
  val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
    .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
      "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")

@@ -2922,6 +2941,12 @@ class SQLConf extends Serializable with Logging {

  def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)

  def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)

  def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_THRESHOLD)

  def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_RATIO)

  def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

  def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)
@@ -353,4 +353,8 @@ object AggUtils {

    finalAndCompleteAggregate :: Nil
  }

  def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = {
    modes.nonEmpty && modes.forall(_ == Partial)
Review thread:
- We cannot apply this optimization in the empty case? e.g., an aggregate with grouping keys but no aggregate functions.
- In this case the reducer side also does not have any aggregate function, and we might end up not aggregating the data.
  }
}
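As a hedged illustration of the empty-modes case raised in the thread above (the concrete example is an assumption, not taken from the review): a DISTINCT is planned as an aggregate with grouping keys but no aggregate functions, so modes is empty and the modes.nonEmpty guard keeps such plans out of the skip-partial-aggregate path.

// Assumed example, reusing a SparkSession named `spark`:
// DISTINCT becomes a HashAggregate with keys only, so areAggExpressionsPartial(Nil) == false.
val deduplicated = spark.range(100).selectExpr("id % 3 AS key").distinct()
deduplicated.explain() // HashAggregate nodes with grouping keys and no aggregate functions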
@@ -63,6 +63,8 @@ case class HashAggregateExec(

  require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

  override def needStopCheck: Boolean = skipPartialAggregate
Review comment: Why do we need to add stopCheck to child operators after this change?
  override lazy val allAttributes: AttributeSeq =
    child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
      aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

@@ -72,6 +74,8 @@ case class HashAggregateExec(
    "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
    "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
    "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
      "number of skipped records for partial aggregates"),
Review comment: I think this metric is only meaningful for aggregates with a partial mode, so could we show it only in that mode?
    "avgHashProbe" ->
      SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))

@@ -409,6 +413,12 @@ case class HashAggregateExec(
  private var fastHashMapTerm: String = _
  private var isFastHashMapEnabled: Boolean = false

  private var avoidSpillInPartialAggregateTerm: String = _
  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
    AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty
Review thread:
- Why do we need this check?
- This is required to avoid the optimization for queries with more than one distinct aggregate, which are planned with an Expand node. Say the dataset contains two records with the same grouping key. With my change partial aggregation is skipped, so the rows produced by Expand reach the reducer without being combined, and the reducer-side aggregation result becomes (key1, 2, 2, 2) instead of the correct counts. Hence the check for the presence of an Expand node, to avoid skipping partial aggregation in that case.
- Hm, I see... But I think the current approach looks dangerous, because we might add a new plan with the same assumption in the future.
- True. Any suggestions on handling this in a better way?
- I have no smart idea for that now. Doesn't Hive have that kind of plan?
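A hedged example of the kind of query in question (the query text is an assumption, not taken from the review): two distinct aggregates over the same grouping key force an Expand node into the plan, and the partial aggregate below it is what collapses the rows Expand duplicates.

// Assumed illustration, reusing a SparkSession named `spark`; explain() shows the Expand node.
spark.sql(
  """SELECT key, COUNT(DISTINCT a), COUNT(DISTINCT b)
    |FROM VALUES ('key1', 1, 2), ('key1', 1, 2) AS t(key, a, b)
    |GROUP BY key""".stripMargin).explain()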
  private var rowCountTerm: String = _
  private var outputFunc: String = _

  // whether a vectorized hashmap is used instead
  // we have decided to always use the row-based hashmap,
  // but the vectorized hashmap can still be switched on for testing and benchmarking purposes.

@@ -628,6 +638,8 @@ case class HashAggregateExec(
         |${consume(ctx, resultVars)}
       """.stripMargin
    }

    ctx.addNewFunction(funcName,
      s"""
         |private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)

@@ -680,6 +692,10 @@ case class HashAggregateExec(

  private def doProduceWithKeys(ctx: CodegenContext): String = {
Review thread:
- Why did you apply this optimization only to the with-key case?
- There will be only one key for the map in the without-key case, so the optimization does not apply.
- Ah, I see.
    val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
    avoidSpillInPartialAggregateTerm = ctx.
      addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate")
    val childrenConsumed = ctx.
      addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
    if (sqlContext.conf.enableTwoLevelAggMap) {
      enableTwoLevelHashMap(ctx)
    } else if (sqlContext.conf.enableVectorizedHashMap) {

@@ -750,18 +766,19 @@ case class HashAggregateExec(
      finishRegularHashMap
    }

    outputFunc = generateResultFunction(ctx)
    val doAggFuncName = ctx.addNewFunction(doAgg,
      s"""
         |private void $doAgg() throws java.io.IOException {
         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
         |  $childrenConsumed = true;
Review thread:
- Why do we need this variable?
- Before my change, all the records of the dataset were processed and the hash map populated (plus spilled to the sorter) at once. With this change the operator can start emitting rows before its input is fully consumed, so the flag records whether the child's records have all been consumed; the produce code below re-invokes doAgg() when they have not.
         |  $finishHashMap
         |}
      """.stripMargin)

    // generate code for output
    val keyTerm = ctx.freshName("aggKey")
    val bufferTerm = ctx.freshName("aggBuffer")
    val outputFunc = generateResultFunction(ctx)
Review comment: Why did you move this line up to line 771?
    def outputFromFastHashMap: String = {
      if (isFastHashMapEnabled) {

@@ -833,11 +850,18 @@ case class HashAggregateExec(
    s"""
       |if (!$initAgg) {
       |  $initAgg = true;
       |  $avoidSpillInPartialAggregateTerm =
       |    ${Utils.isTesting} && $skipPartialAggregate;
       |  $createFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
       |  $shouldStopCheckCode;
       |}
       |if (!$childrenConsumed) {
       |  $doAggFuncName();
       |  $shouldStopCheckCode;
       |}
       |// output the result
       |$outputFromFastHashMap
@@ -877,44 +901,61 @@ case class HashAggregateExec(
      ("true", "true", "", "")
    }

    val skipPartialAggregateThreshold = sqlContext.conf.skipPartialAggregateThreshold
    val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio
    val oomeClassName = classOf[SparkOutOfMemoryError].getName
    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
    val findOrInsertRegularHashMap: String =
      s"""
         |if (!$avoidSpillInPartialAggregateTerm) {
         |  // generate grouping key
         |  ${unsafeRowKeyCode.code}
         |  int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
         |  if ($checkFallbackForBytesToBytesMap) {
         |    // try to get the buffer from hash map
         |    $unsafeRowBuffer =
         |      $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
         |  }
         |  // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
         |  // aggregation after processing all input rows.
         |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
Review comment: Is this check redundant? We are already checking for !$avoidSpillInPartialAggregateTerm in the enclosing if.
         |    // If sort/spill to disk is disabled, nothing is done.
         |    // Aggregation buffer is created later
         |    $countTerm = $countTerm + $hashMapTerm.getNumRows();
         |    boolean skipPartAgg =
         |      !($rowCountTerm < $skipPartialAggregateThreshold) &&
         |      ($countTerm/$rowCountTerm) > $skipPartialAggRatio;
         |    if ($skipPartialAggregate && skipPartAgg) {
         |      $avoidSpillInPartialAggregateTerm = true;
         |    } else {
         |      if ($sorterTerm == null) {
         |        $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
         |      } else {
         |        $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
         |      }
         |      $resetCounter
         |      // the hash map had be spilled, it should have enough memory now,
         |      // try to allocate buffer again.
         |      $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
         |        $unsafeRowKeys, $unsafeRowKeyHash);
         |      if ($unsafeRowBuffer == null) {
         |        // failed to allocate the first page
         |        throw new $oomeClassName("No enough memory for aggregation");
         |      }
         |    }
         |  }
         |}
      """.stripMargin

    val partTerm = metricTerm(ctx, "partialAggSkipped")

    val findOrInsertHashMap: String = {
      val insertCode = if (isFastHashMapEnabled) {
        // If fast hash map is on, we first generate code to probe and update the fast hash map.
        // If the probe is successful the corresponding fast row buffer will hold the mutable row.
        s"""
           |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) {
           |  ${fastRowKeys.map(_.code).mkString("\n")}
           |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
           |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
@@ -923,12 +964,25 @@ case class HashAggregateExec(
           |}
           |// Cannot find the key in fast hash map, try regular hash map.
           |if ($fastRowBuffer == null) {
           |  $countTerm = $countTerm + $fastHashMapTerm.getNumRows();
           |  $findOrInsertRegularHashMap
           |}
         """.stripMargin
      } else {
        findOrInsertRegularHashMap
      }
      val initExpr = declFunctions.flatMap(f => f.initialValues)
      val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr)
      s"""
         |$insertCode
         |// Create an empty aggregation buffer
         |if ($avoidSpillInPartialAggregateTerm) {
         |  ${unsafeRowKeyCode.code}
         |  ${emptyBufferKeyCode.code}
         |  $unsafeRowBuffer = ${emptyBufferKeyCode.value};
         |  $partTerm.add(1);
         |}
         |""".stripMargin
    }

    val inputAttr = aggregateBufferAttributes ++ inputAttributes

@@ -1005,7 +1059,7 @@ case class HashAggregateExec(
    }

    val updateRowInHashMap: String = {
      val updateRowinMap = if (isFastHashMapEnabled) {
        if (isVectorizedHashMapEnabled) {
          ctx.INPUT_ROW = fastRowBuffer
          val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>

@@ -1080,6 +1134,12 @@ case class HashAggregateExec(
      } else {
        updateRowInRegularHashMap
      }
      s"""
         |$updateRowinMap
         |if ($avoidSpillInPartialAggregateTerm) {
         |  $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer);
         |}
         |""".stripMargin
    }

    val declareRowBuffer: String = if (isFastHashMapEnabled) {
Review thread:
- So this only works for hash aggregate but not for sort-based aggregate?
- I believe this heuristic can be applied to sort-based aggregation as well. I started with hash-based aggregate; I will create a new PR for sort-based aggregation.