[SPARK-23099][SS] Migrate foreach sink to DataSourceV2 #20552

Closed · wants to merge 6 commits
@@ -17,52 +17,148 @@

package org.apache.spark.sql.execution.streaming

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType


case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
override def createStreamWriter(
queryId: String,
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
val encoder = encoderFor[T].resolveAndBind(
schema.toAttributes,
SparkSession.getActiveSession.get.sessionState.analyzer)
new StreamWriter with SupportsWriteInternalRow {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
val byteStream = new ByteArrayOutputStream()
val objectStream = new ObjectOutputStream(byteStream)
objectStream.writeObject(writer)
ForeachWriterFactory(byteStream.toByteArray, encoder)
}
}
}
}

case class ForeachWriterFactory[T: Encoder](
serializedWriter: Array[Byte],
encoder: ExpressionEncoder[T])
extends DataWriterFactory[InternalRow] {
override def createDataWriter(partitionId: Int, attemptNumber: Int): ForeachDataWriter[T] = {
new ForeachDataWriter(serializedWriter, encoder, partitionId)
}
}

/**
* A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
* [[ForeachWriter]].
* A [[DataWriter]] for the foreach sink.
*
* Note that [[ForeachWriter]] has the following lifecycle, and (as was true in the V1 sink API)
* assumes that it's never reused:
* * [create writer]
* * open(partitionId, batchId)
* * if open() returned true: write, write, write, ...
* * close()
* while DataSourceV2 writers have a slightly different lifecycle and will be reused for multiple
* epochs in the continuous processing engine:
* * [create writer]
* * write, write, write, ...
* * commit()
*
* @param writer The [[ForeachWriter]] to process all data.
* @tparam T The expected type of the sink.
* The bulk of the implementation here is a shim between these two models.
*
* @param serializedWriter a serialized version of the user-provided [[ForeachWriter]]
* @param encoder encoder converting an [[InternalRow]] into the type param [[T]]
* @param partitionId the ID of the partition this data writer is responsible for
*
* @tparam T the type of data to be handled by the writer
*/
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {

override def addBatch(batchId: Long, data: DataFrame): Unit = {
// This logic should've been as simple as:
// ```
// data.as[T].foreachPartition { iter => ... }
// ```
//
// Unfortunately, doing that would just break the incremental planning. The reason is,
// `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
// create a new plan. Because StreamExecution uses the existing plan to collect metrics and
// update the watermark, we should never create a new plan. Otherwise, metrics and the watermark
// are updated in the new plan, and StreamExecution cannot retrieve them.
//
// Hence, we need to manually convert internal rows to objects using encoder.
val encoder = encoderFor[T].resolveAndBind(
data.logicalPlan.output,
data.sparkSession.sessionState.analyzer)
data.queryExecution.toRdd.foreachPartition { iter =>
if (writer.open(TaskContext.getPartitionId(), batchId)) {
try {
while (iter.hasNext) {
writer.process(encoder.fromRow(iter.next()))
}
} catch {
case e: Throwable =>
writer.close(e)
throw e
}
writer.close(null)
} else {
writer.close(null)
class ForeachDataWriter[T : Encoder](
serializedWriter: Array[Byte],
encoder: ExpressionEncoder[T],
partitionId: Int)
extends DataWriter[InternalRow] {
private val initialEpochId: Long = {
// Start with the microbatch ID. If it's not there, we're in continuous execution,
// so get the start epoch.
// This ID will be incremented as commits happen.
TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) match {
case null => TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
case batch => batch.toLong
}
}

// A small state machine representing the lifecycle of the underlying ForeachWriter.
// * CLOSED means close() has been called.
// * OPENED means open() was called and returned true.
// * OPENED_SKIP_PROCESSING means open() was called and returned false.
private object WriterState extends Enumeration {
type WriterState = Value
val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
}
import WriterState._

private var writer: ForeachWriter[T] = _
private var state: WriterState = _
private var currentEpochId = initialEpochId

private def openAndSetState(epochId: Long) = {
writer = new ObjectInputStream(new ByteArrayInputStream(serializedWriter)).readObject()
.asInstanceOf[ForeachWriter[T]]

writer.open(partitionId, epochId) match {
case true => state = OPENED
case false => state = OPENED_SKIP_PROCESSING
}
}

openAndSetState(initialEpochId)

override def write(record: InternalRow): Unit = {
try {
state match {
case OPENED => writer.process(encoder.fromRow(record))
case OPENED_SKIP_PROCESSING => ()
case CLOSED =>
// First record of a new epoch, so we need to open a new writer for it.
openAndSetState(currentEpochId)
writer.process(encoder.fromRow(record))
}
} catch {
case t: Throwable =>
writer.close(t)
throw t
}
}

override def commit(): WriterCommitMessage = {
// Close if the writer got opened for this epoch.
state match {
case CLOSED => ()
case _ => writer.close(null)
}
state = CLOSED
currentEpochId += 1
ForeachWriterCommitMessage
}

override def toString(): String = "ForeachSink"
override def abort(): Unit = {}
}

/**
* An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
*/
case object ForeachWriterCommitMessage extends WriterCommitMessage
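
For context on the contract the shim above has to honor, a minimal user-side sketch of supplying a [[ForeachWriter]] through `DataStreamWriter.foreach` is shown below. The object name, rate-source options, and writer body are illustrative assumptions, not part of this patch.

```scala
import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}

// Illustrative only: drives the open/process/close lifecycle that
// ForeachDataWriter re-creates on top of the DataSourceV2 write/commit model.
object ForeachLifecycleExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-lifecycle-example")
      .getOrCreate()

    // Any streaming Dataset works; the rate source is just a convenient toy input.
    val rates = spark.readStream.format("rate").option("rowsPerSecond", "1").load()

    val query = rates.writeStream
      .foreach(new ForeachWriter[Row] {
        // Called once per partition per epoch; returning false skips process() calls.
        override def open(partitionId: Long, epochId: Long): Boolean = true

        // Called for every row of the partition while the writer is open.
        override def process(value: Row): Unit = println(s"wrote: $value")

        // Called once at the end of the epoch; errorOrNull is non-null on failure.
        override def close(errorOrNull: Throwable): Unit = ()
      })
      .start()

    query.awaitTermination()
  }
}
```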
@@ -467,6 +467,9 @@ class MicroBatchExecution(
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}

sparkSession.sparkContext.setLocalProperty(
MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString)

reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSessionToRunBatch,
@@ -507,4 +510,7 @@
}
}

object MicroBatchExecution {
val BATCH_ID_KEY = "sql.streaming.microbatch.batchId"
}
object FakeDataSourceV2 extends DataSourceV2
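
The batch ID set above travels from the driver to the executor-side ForeachDataWriter through Spark's local-property mechanism. Below is a standalone sketch of that mechanism, using an illustrative key rather than BATCH_ID_KEY and assuming local mode:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("local-property-demo"))

    // Properties set on the driver thread are captured when a job is submitted
    // from that thread and become readable inside its tasks via TaskContext.
    sc.setLocalProperty("example.batchId", "42")

    val seen = sc.parallelize(1 to 4, numSlices = 2)
      .map(_ => TaskContext.get().getLocalProperty("example.batchId"))
      .collect()

    assert(seen.forall(_ == "42"))
    sc.stop()
  }
}
```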
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc)
val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.streaming

import java.io.Serializable
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable
@@ -26,7 +27,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger}
import org.apache.spark.sql.test.SharedSQLContext

class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
@@ -141,7 +142,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
query.processAllAvailable()
}
assert(e.getCause.isInstanceOf[SparkException])
assert(e.getCause.getCause.getMessage === "error")
assert(e.getCause.getCause.getCause.getMessage === "error")
assert(query.isActive === false)

val allEvents = ForeachSinkSuite.allEvents()
@@ -255,6 +256,89 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
query.stop()
}
}

Reviewer (Contributor): I think there should be a test with continuous processing + foreach.

Author reply: Good instinct, it didn't quite work. Added the test.

testQuietly("foreach does not reuse writers") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(1).writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.foreach(new TestForeachWriter() {
override def process(value: Int): Unit = {
super.process(this.hashCode())
}
}).start()
input.addData(0)
query.processAllAvailable()
input.addData(0)
query.processAllAvailable()

val allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents(0)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
val firstWriterId = allEvents(0)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value
assert(allEvents(1)(1).isInstanceOf[ForeachSinkSuite.Process[Int]])
assert(
allEvents(1)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value != firstWriterId,
"writer was reused!")
}
}

testQuietly("foreach sink for continuous query") {
withTempDir { checkpointDir =>
val query = spark.readStream
.format("rate")
.option("numPartitions", "1")
.option("rowsPerSecond", "5")
.load()
.select('value.cast("INT"))
.map(r => r.getInt(0))
.writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.trigger(Trigger.Continuous(500))
.foreach(new TestForeachWriter with Serializable {
override def process(value: Int): Unit = {
super.process(this.hashCode())
}
}).start()
try {
// Wait until we get 3 epochs with at least 3 events in them. This means we'll see
// open, close, and at least 1 process.
eventually(timeout(streamingTimeout)) {
// Three epochs should each have accumulated an open, at least one process, and a close event.
assert(ForeachSinkSuite.allEvents().count(_.size >= 3) === 3)
}

val allEvents = ForeachSinkSuite.allEvents().filter(_.size >= 3)
// Check open and close events.
allEvents(0).head match {
case ForeachSinkSuite.Open(0, _) =>
case e => assert(false, s"unexpected event $e")
}
allEvents(1).head match {
case ForeachSinkSuite.Open(0, _) =>
case e => assert(false, s"unexpected event $e")
}
allEvents(2).head match {
case ForeachSinkSuite.Open(0, _) =>
case e => assert(false, s"unexpected event $e")
}
assert(allEvents(0).last == ForeachSinkSuite.Close(None))
assert(allEvents(1).last == ForeachSinkSuite.Close(None))
assert(allEvents(2).last == ForeachSinkSuite.Close(None))

// Check the first Process event in each epoch, and also check the writer IDs
// we packed in to make sure none got reused.
val writerIds = (0 to 2).map { i =>
allEvents(i)(1).asInstanceOf[ForeachSinkSuite.Process[Int]].value
}
assert(
writerIds.toSet.size == 3,
s"writer was reused! expected 3 unique writers but saw $writerIds")
} finally {
query.stop()
}
}
}
}

/** A global object to collect events in the executor */
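
The collector object referenced by these tests (`ForeachSinkSuite.allEvents()`, `Open`, `Process`, `Close`) is collapsed in this view. A rough sketch of the pattern such a collector follows, with names and shapes assumed for illustration rather than taken from the suite:

```scala
import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._

// Illustrative event collector: each writer instance records its lifecycle
// events, and the test driver reads them back as one Seq[Event] per writer.
object ExampleEventCollector {
  sealed trait Event
  case class Open(partition: Long, version: Long) extends Event
  case class Process[T](value: T) extends Event
  case class Close(error: Option[Throwable]) extends Event

  // Safe to append from executor threads when running in local mode.
  private val events = new ConcurrentLinkedQueue[Seq[Event]]()

  def addEvents(writerEvents: Seq[Event]): Unit = events.add(writerEvents)

  def allEvents(): Seq[Seq[Event]] = events.iterator().asScala.toSeq

  def clear(): Unit = events.clear()
}
```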