[SPARK-8357] Fix unsafe memory leak on empty inputs in GeneratedAggregate #7560

Closed
wants to merge 13 commits into from
@@ -266,7 +266,18 @@ case class GeneratedAggregate(

val joinedRow = new JoinedRow3

- if (groupingExpressions.isEmpty) {
if (!iter.hasNext) {
// This is an empty input, so return early so that we do not allocate data structures
// that won't be cleaned up (see SPARK-8357).
if (groupingExpressions.isEmpty) {
Contributor Author

Here, I made a slight simplification compared to @navis's original patch: if groupingExpressions is empty and the input is empty, then always return an empty aggregation buffer. @navis's patch contained an additional branch here that skipped this output when partial = true, but I think that is an unnecessary performance optimization, given that the non-generated Aggregate operator still outputs an empty row even on empty inputs. Removing this branch means fewer cases to test.
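To make the intended empty-input behavior concrete, here is a rough sketch on plain Scala collections. It is a simplified model under stated assumptions, not the operator code: `EmptyInputAggregateSketch` and `countAggregate` are hypothetical names, rows are modeled as `(key, value)` pairs, and COUNT stands in for an arbitrary aggregate.

```scala
// Simplified model of the empty-input handling discussed above.
// All names here are hypothetical; this is not Spark's GeneratedAggregate.
object EmptyInputAggregateSketch {

  // Counts rows either globally (grouped = false) or per key (grouped = true).
  def countAggregate(rows: Iterator[(Int, Int)], grouped: Boolean): Iterator[Long] = {
    if (!rows.hasNext) {
      // Empty input: return early, before any map or buffer is allocated.
      if (!grouped) {
        // Global aggregate: emit one row holding the initial buffer (COUNT = 0).
        Iterator(0L)
      } else {
        // Grouped aggregate: there are no groups, so there are no output rows.
        Iterator.empty
      }
    } else if (!grouped) {
      // Global aggregate over non-empty input: a single output row.
      Iterator(rows.size.toLong)
    } else {
      // Only this path builds a map, and it always sees at least one row.
      rows.toSeq.groupBy(_._1).valuesIterator.map(_.size.toLong)
    }
  }
}
```

In this model, the `partial = true` special case from the original patch would amount to returning `Iterator.empty` in the global branch as well; always emitting the initial buffer keeps one code path for both modes.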

// This is a global aggregate, so return an empty aggregation buffer.
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
} else {
// This is a grouped aggregate, so return an empty iterator.
Iterator[InternalRow]()
}
} else if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: InternalRow = null
@@ -280,6 +291,7 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
} else if (unsafeEnabled) {
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer,
@@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(2, 1, 2, 2, 1))
}

test("count of empty table") {
Contributor Author

I added this test to make it easier to catch mistakes in the implementation of the groupingExpressions.isEmpty && !iter.hasNext case. A wrong implementation that does not return an empty buffer will be caught quickly by this test.

withTempTable("t") {
Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t")
checkAnswer(
sql("select count(a) from t"),
Row(0))
}
}

test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext

class AggregateSuite extends SparkPlanTest {

test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
Contributor Author

I simplified this test by removing the big cross-product of options, since I think there's only one problematic case that we really need a regression test for.

val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
try {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
val df = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
df,
GeneratedAggregate(
partial = true,
Seq(df.col("b").expr),
Seq(Alias(Count(df.col("a").expr), "cnt")()),
unsafeEnabled = true,
_: SparkPlan),
Seq.empty
)
} finally {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
}
}
}