[SPARK-48772][SS][SQL] State Data Source Change Feed Reader Mode
### What changes were proposed in this pull request?

This PR adds to the state data source the ability to expose the evolution of state in Change Data Capture (CDC) format.

An example usage:
```scala
spark.read
  .format("statestore")
  .option("readChangeFeed", true)
  .option("changeStartBatchId", 5) // required
  .option("changeEndBatchId", 10)  // optional; default: the latest batch ID available
  .load(checkpointLocation)
```
_Note that this mode does not support the option "joinSide"._
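Read this way, the source returns one row per state change, with two new leading columns, `batch_id` and `change_type` (see the schema change in the diff below). The rows here are purely illustrative (the `key` and `value` structs depend on the stateful operator, and the values are made up); output for a simple streaming count might look like:

```
+--------+-----------+--------+-----+------------+
|batch_id|change_type|key     |value|partition_id|
+--------+-----------+--------+-----+------------+
|5       |update     |{apple} |{2}  |0           |
|6       |update     |{apple} |{3}  |0           |
|6       |delete     |{banana}|NULL |1           |
+--------+-----------+--------+-----+------------+
```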

### Why are the changes needed?

The current state reader can only return the entire state at a specific version. When a state-related error occurs, it is important for debugging to see how state changed across versions and to pinpoint the version at which it started to go wrong.
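As a concrete sketch of that debugging workflow (the checkpoint path, key field name, and filter value here are hypothetical):

```scala
import org.apache.spark.sql.functions.col

// Replay all recorded changes for one suspicious key and find the first
// batch at which its value starts to diverge from expectations.
val changes = spark.read
  .format("statestore")
  .option("readChangeFeed", true)
  .option("changeStartBatchId", 0) // scan from the first batch
  .load("/tmp/streaming-checkpoint")

changes
  .filter(col("key.groupKey") === "suspect-key") // assumes the key struct has a `groupKey` field
  .orderBy("batch_id")
  .show(truncate = false)
```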

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Adds a new test suite, `StateDataSourceChangeDataReadSuite`, which covers 1) input error handling, 2) the newly added API, and 3) an integration test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47188 from eason-yuchen-liu/readStateChange.

Lead-authored-by: Yuchen Liu <[email protected]>
Co-authored-by: Yuchen Liu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
2 people authored and biruktesf-db committed Jul 11, 2024
1 parent c558bd1 commit abb78dc
Showing 13 changed files with 812 additions and 104 deletions.
```diff
@@ -3835,7 +3835,7 @@
   "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : {
     "message" : [
       "The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.",
-      "Therefore, it does not support option snapshotStartBatchId in state data source."
+      "Therefore, it does not support option snapshotStartBatchId or readChangeFeed in state data source."
     ],
     "sqlState" : "42K06"
   },
```
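For context, both fine-grained replay features (`snapshotStartBatchId` and the new `readChangeFeed`) require the configured provider to extend `SupportsFineGrainedReplay`. A minimal sketch of how this error would now surface (the provider class and checkpoint path are hypothetical):

```scala
// Hypothetical provider class that does NOT extend SupportsFineGrainedReplay.
spark.conf.set(
  "spark.sql.streaming.stateStore.providerClass",
  "com.example.NoReplayStateStoreProvider")

// Under that provider, this read fails with
// STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY.
spark.read
  .format("statestore")
  .option("readChangeFeed", true)
  .option("changeStartBatchId", 0)
  .load("/tmp/streaming-checkpoint")
```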
```diff
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
 
```
```diff
@@ -94,10 +94,20 @@ class StateDataSource extends TableProvider with DataSourceRegister {
         manager.readSchemaFile()
       }
 
-      new StructType()
-        .add("key", keySchema)
-        .add("value", valueSchema)
-        .add("partition_id", IntegerType)
+      if (sourceOptions.readChangeFeed) {
+        new StructType()
+          .add("batch_id", LongType)
+          .add("change_type", StringType)
+          .add("key", keySchema)
+          .add("value", valueSchema)
+          .add("partition_id", IntegerType)
+      } else {
+        new StructType()
+          .add("key", keySchema)
+          .add("value", valueSchema)
+          .add("partition_id", IntegerType)
+      }
+
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
```
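Sketched as `printSchema()` output, the change feed mode simply prepends the two new columns to the existing layout (the `key` and `value` structs depend on the operator and are elided here):

```
root
 |-- batch_id: long (nullable = true)
 |-- change_type: string (nullable = true)
 |-- key: struct (nullable = true)
 |    |-- ... (operator-specific fields)
 |-- value: struct (nullable = true)
 |    |-- ... (operator-specific fields)
 |-- partition_id: integer (nullable = true)
```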
```diff
@@ -125,21 +135,38 @@ class StateDataSource extends TableProvider with DataSourceRegister {
   override def supportsExternalMetadata(): Boolean = false
 }
 
+case class FromSnapshotOptions(
+    snapshotStartBatchId: Long,
+    snapshotPartitionId: Int)
+
+case class ReadChangeFeedOptions(
+    changeStartBatchId: Long,
+    changeEndBatchId: Long
+)
+
 case class StateSourceOptions(
     resolvedCpLocation: String,
     batchId: Long,
     operatorId: Int,
     storeName: String,
     joinSide: JoinSideValues,
-    snapshotStartBatchId: Option[Long],
-    snapshotPartitionId: Option[Int]) {
+    readChangeFeed: Boolean,
+    fromSnapshotOptions: Option[FromSnapshotOptions],
+    readChangeFeedOptions: Option[ReadChangeFeedOptions]) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)
 
   override def toString: String = {
-    s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
-      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
-      s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " +
-      s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})"
+    var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
+      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide"
+    if (fromSnapshotOptions.isDefined) {
+      desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
+      desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
+    }
+    if (readChangeFeedOptions.isDefined) {
+      desc += s", changeStartBatchId=${readChangeFeedOptions.get.changeStartBatchId}"
+      desc += s", changeEndBatchId=${readChangeFeedOptions.get.changeEndBatchId}"
+    }
+    desc + ")"
   }
 }
 
```
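With the two new option groups, `toString` now prints only the fields that are actually set. For a change feed read over batches 5 through 10, the rendering would look roughly like this (checkpoint path illustrative; `batchId` is pinned to `changeEndBatchId` by the `apply` logic shown further down):

```
StateSourceOptions(checkpointLocation=/tmp/streaming-checkpoint, batchId=10, operatorId=0, storeName=default, joinSide=none, changeStartBatchId=5, changeEndBatchId=10)
```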

```diff
@@ -151,6 +178,9 @@ object StateSourceOptions extends DataSourceOptions {
   val JOIN_SIDE = newOption("joinSide")
   val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
   val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")
+  val READ_CHANGE_FEED = newOption("readChangeFeed")
+  val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
+  val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
```
```diff
@@ -172,16 +202,6 @@ object StateSourceOptions extends DataSourceOptions {
       throw StateDataSourceErrors.requiredOptionUnspecified(PATH)
     }.get
 
-    val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation)
-
-    val batchId = Option(options.get(BATCH_ID)).map(_.toLong).orElse {
-      Some(getLastCommittedBatch(sparkSession, resolvedCpLocation))
-    }.get
-
-    if (batchId < 0) {
-      throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID)
-    }
-
     val operatorId = Option(options.get(OPERATOR_ID)).map(_.toInt)
       .orElse(Some(0)).get
 
```
```diff
@@ -210,30 +230,97 @@
       throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
     }
 
-    val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
-    if (snapshotStartBatchId.exists(_ < 0)) {
-      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
-    } else if (snapshotStartBatchId.exists(_ > batchId)) {
-      throw StateDataSourceErrors.invalidOptionValue(
-        SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
-    }
+    val resolvedCpLocation = resolvedCheckpointLocation(hadoopConf, checkpointLocation)
+
+    var batchId = Option(options.get(BATCH_ID)).map(_.toLong)
 
+    val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
     val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
-    if (snapshotPartitionId.exists(_ < 0)) {
-      throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
-    }
 
-    // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
-    // each partition may have different checkpoint status
-    if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
-      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
-    } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
-      throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
+    val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)
+
+    val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
+    var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)
+
+    var fromSnapshotOptions: Option[FromSnapshotOptions] = None
+    var readChangeFeedOptions: Option[ReadChangeFeedOptions] = None
+
+    if (readChangeFeed) {
+      if (joinSide != JoinSideValues.none) {
+        throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED))
+      }
+      if (batchId.isDefined) {
+        throw StateDataSourceErrors.conflictOptions(Seq(BATCH_ID, READ_CHANGE_FEED))
+      }
+      if (snapshotStartBatchId.isDefined) {
+        throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_START_BATCH_ID, READ_CHANGE_FEED))
+      }
+      if (snapshotPartitionId.isDefined) {
+        throw StateDataSourceErrors.conflictOptions(Seq(SNAPSHOT_PARTITION_ID, READ_CHANGE_FEED))
+      }
+
+      if (changeStartBatchId.isEmpty) {
+        throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID)
+      }
+      changeEndBatchId = Some(
+        changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
+
+      // changeStartBatchId and changeEndBatchId must all be defined at this point
+      if (changeStartBatchId.get < 0) {
+        throw StateDataSourceErrors.invalidOptionValueIsNegative(CHANGE_START_BATCH_ID)
+      }
+      if (changeEndBatchId.get < changeStartBatchId.get) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID,
+          s"$CHANGE_END_BATCH_ID cannot be smaller than $CHANGE_START_BATCH_ID. " +
+          s"Please check the input to $CHANGE_END_BATCH_ID, or if you are using its default " +
+          s"value, make sure that $CHANGE_START_BATCH_ID is less than ${changeEndBatchId.get}.")
+      }
+
+      batchId = Some(changeEndBatchId.get)
+
+      readChangeFeedOptions = Option(
+        ReadChangeFeedOptions(changeStartBatchId.get, changeEndBatchId.get))
+    } else {
+      if (changeStartBatchId.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_START_BATCH_ID,
+          s"Only specify this option when $READ_CHANGE_FEED is set to true.")
+      }
+      if (changeEndBatchId.isDefined) {
+        throw StateDataSourceErrors.invalidOptionValue(CHANGE_END_BATCH_ID,
+          s"Only specify this option when $READ_CHANGE_FEED is set to true.")
+      }
+
+      batchId = Some(batchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
+
+      if (batchId.get < 0) {
+        throw StateDataSourceErrors.invalidOptionValueIsNegative(BATCH_ID)
+      }
+      if (snapshotStartBatchId.exists(_ < 0)) {
+        throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
+      } else if (snapshotStartBatchId.exists(_ > batchId.get)) {
+        throw StateDataSourceErrors.invalidOptionValue(
+          SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to ${batchId.get}")
+      }
+      if (snapshotPartitionId.exists(_ < 0)) {
+        throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
+      }
+      // both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
+      // each partition may have different checkpoint status
+      if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
+        throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
+      } else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
+        throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
+      }
+
+      if (snapshotStartBatchId.isDefined && snapshotPartitionId.isDefined) {
+        fromSnapshotOptions = Some(
+          FromSnapshotOptions(snapshotStartBatchId.get, snapshotPartitionId.get))
+      }
     }
 
     StateSourceOptions(
-      resolvedCpLocation, batchId, operatorId, storeName,
-      joinSide, snapshotStartBatchId, snapshotPartitionId)
+      resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
+      readChangeFeed, fromSnapshotOptions, readChangeFeedOptions)
   }
 
   private def resolvedCheckpointLocation(
```
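Taken together, the checks above define a strict option matrix for the new mode: `readChangeFeed` conflicts with `joinSide`, `batchId`, `snapshotStartBatchId`, and `snapshotPartitionId`; it requires `changeStartBatchId`; and `changeEndBatchId` defaults to the last committed batch. A sketch of one invalid combination (hypothetical checkpoint path) that now fails fast:

```scala
// Throws a conflictOptions error: batchId cannot be combined with readChangeFeed.
spark.read
  .format("statestore")
  .option("readChangeFeed", true)
  .option("batchId", 3)
  .option("changeStartBatchId", 0)
  .load("/tmp/streaming-checkpoint")
```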
[Diffs for the remaining changed files were not loaded.]
