[SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task #21124

Closed · wants to merge 2 commits
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
         // Update and output modified rows from the StateStore.
         case Some(Update) =>
 
-          val updatesStartTimeNs = System.nanoTime
-
-          new Iterator[InternalRow] {
-
+          new NextIterator[InternalRow] {
             // Filter late date using watermark if specified
             private[this] val baseIterator = watermarkPredicateForData match {
               case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
               case None => iter
             }
+            private val updatesStartTimeNs = System.nanoTime
 
-            override def hasNext: Boolean = {
-              if (!baseIterator.hasNext) {
-                allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-                // Remove old aggregates if watermark specified
-                allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
-                commitTimeMs += timeTakenMs { store.commit() }
-                setStoreMetrics(store)
-                false
+            override protected def getNext(): InternalRow = {
+              if (baseIterator.hasNext) {
+                val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numOutputRows += 1
+                numUpdatedStateRows += 1
+                row
               } else {
-                true
+                finished = true
+                null
               }
             }
 
-            override def next(): InternalRow = {
-              val row = baseIterator.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key, row)
-              numOutputRows += 1
-              numUpdatedStateRows += 1
-              row
+            override protected def close(): Unit = {
+              allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+              // Remove old aggregates if watermark specified
+              allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+              commitTimeMs += timeTakenMs { store.commit() }
+              setStoreMetrics(store)
             }
           }
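Why the change matters: before this patch, the Update-mode output iterator finalized the state store inside hasNext (removing expired keys, calling store.commit(), recording metrics). A downstream operator may call hasNext again after the iterator is exhausted — ObjectHashAggregateExec does so when it falls back to sort-based aggregation — which could invoke store.commit() a second time. NextIterator instead funnels all completion work through a close() hook that runs at most once. Below is a minimal standalone sketch of that contract; the real class is org.apache.spark.util.NextIterator, and the code here is illustrative rather than the exact Spark source.

// Sketch of the NextIterator contract this patch relies on (illustrative, not Spark source).
abstract class SketchNextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false
  protected var finished = false

  // Produce the next value, or set `finished = true` when the input is exhausted.
  protected def getNext(): U

  // Completion hook: runs at most once, no matter how often hasNext is called afterwards.
  protected def close(): Unit

  private def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded() // guarded, so repeated hasNext calls cannot re-run close()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

With this contract, StateStoreSaveExec moves store.commit() and the metric updates into close(), so an extra hasNext call after exhaustion is harmless.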

@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
+  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+    // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+    // by ensuring the following.
+    // - A streaming query with a streaming aggregation.
+    // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+    // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+    //   ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+    //   micro-batch with 128 records that shuffle to a single partition.
+    // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val input = MemoryStream[Int]
+      val df = input.toDF().toDF("value")
+        .selectExpr("value as group", "value")
+        .groupBy("group")
+        .agg(collect_list("value"))
+      testStream(df, outputMode = OutputMode.Update)(
+        AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+        AssertOnQuery { q =>
+          q.processAllAvailable()
+          true
+        }
+      )
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {
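For reference, the 128-record figure in the new test corresponds to spark.sql.objectHashAggregate.sortBased.fallbackThreshold (default 128), which the test reads through spark.sqlContext.conf.objectAggSortBasedFallbackThreshold. A variant of the same repro that pins the threshold explicitly instead of relying on the default could look like the sketch below; it reuses the suite's withSQLConf/testStream helpers and is illustrative only, not part of this PR.

// Sketch only: pin the fallback threshold rather than reading the default, so the repro
// does not silently stop exercising the fallback path if the default ever changes.
withSQLConf(
    "spark.sql.shuffle.partitions" -> "1",
    "spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "128") {
  val input = MemoryStream[Int]
  val df = input.toDF().toDF("value")
    .selectExpr("value as group", "value")
    .groupBy("group")
    .agg(collect_list("value"))
  testStream(df, outputMode = OutputMode.Update)(
    AddData(input, (1 to 128): _*), // exactly the pinned fallback threshold
    AssertOnQuery { q =>
      q.processAllAvailable()
      true
    }
  )
}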