Commit 1a23abb

refactor the code to isolate from current state stores used by streaming queries

eason-yuchen-liu committed Jun 25, 2024
1 parent 876256e commit 1a23abb
Showing 7 changed files with 65 additions and 71 deletions.
File: StatePartitionReader.scala
@@ -22,7 +22,7 @@
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration

@@ -96,9 +96,10 @@ class StatePartitionReader(
     partition.sourceOptions.snapshotStartBatchId match {
       case None => provider.getReadStore(partition.sourceOptions.batchId + 1)

-      case Some(snapshotStartBatchId) => provider.getReadStore(
-        snapshotStartBatchId + 1,
-        partition.sourceOptions.batchId + 1)
+      case Some(snapshotStartBatchId) =>
+        provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(
+          snapshotStartBatchId + 1,
+          partition.sourceOptions.batchId + 1)
     }
   }
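This hunk is where the state data source picks the replay path: when the snapshotStartBatchId option is set, the reader casts the provider to the new FineGrainedStateSource trait and replays from that snapshot instead of loading the batch directly. A hedged usage sketch of the option follows; the format name and snapshotStartBatchId appear in this commit, while the checkpoint path, the batch ids, and the snapshotPartitionId pairing are assumptions for illustration:

import org.apache.spark.sql.SparkSession

object SnapshotReplayReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val stateDf = spark.read
      .format("statestore")
      .option("path", "/tmp/streaming-checkpoint") // assumed checkpoint location
      .option("snapshotStartBatchId", 5L)          // replay starts from this snapshot
      .option("snapshotPartitionId", 0L)           // assumed: paired with snapshotStartBatchId
      .option("batchId", 10L)                      // state as of this batch
      .load()
    stateDf.show()
    spark.stop()
  }
}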
File: HDFSBackedStateStoreProvider.scala
@@ -71,7 +71,8 @@ import org.apache.spark.util.ArrayImplicits._
  * to ensure re-executed RDD operations re-apply updates on the correct past version of the
  * store.
  */
-private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
+private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
+  with FineGrainedStateSource {

   private val providerName = "HDFSBackedStateStoreProvider"

@@ -269,8 +270,8 @@
    * @param endVersion checkpoint version to end with
    * @return [[HDFSBackedStateStore]]
    */
-  override def getStore(startVersion: Long, endVersion: Long): StateStore = {
-    val newMap = getLoadedMapForStore(startVersion, endVersion)
+  override def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
+    val newMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
     logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
       log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
       log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
@@ -293,8 +294,8 @@
    * @param endVersion checkpoint version to end with
    * @return [[HDFSBackedReadStateStore]]
    */
-  override def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore = {
-    val newMap = getLoadedMapForStore(startVersion, endVersion)
+  override def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): ReadStateStore = {
+    val newMap = replayLoadedMapForStoreFromSnapshot(startVersion, endVersion)
     logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
       log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
       log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly")
@@ -323,7 +324,7 @@
    * @param startVersion checkpoint version of the snapshot to start with
    * @param endVersion checkpoint version to end with
    */
-  private def getLoadedMapForStore(startVersion: Long, endVersion: Long):
+  private def replayLoadedMapForStoreFromSnapshot(startVersion: Long, endVersion: Long):
     HDFSBackedStateStoreMap = synchronized {
     try {
       if (startVersion < 1) {
@@ -335,7 +336,7 @@

       val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
       if (endVersion != 0) {
-        newMap.putAll(loadMap(startVersion, endVersion))
+        newMap.putAll(constructMapFromSnapshot(startVersion, endVersion))
       }
       newMap
     }
@@ -603,25 +604,17 @@
     result
   }

-  private def loadMap(startVersion: Long, endVersion: Long): HDFSBackedStateStoreMap = {
+  private def constructMapFromSnapshot(startVersion: Long, endVersion: Long):
+    HDFSBackedStateStoreMap = {
     val (result, elapsedMs) = Utils.timeTakenMs {
       val startVersionMap = synchronized { Option(loadedMaps.get(startVersion)) } match {
-        case Some(value) =>
-          loadedMapCacheHitCount.increment()
-          Option(value)
-        case None =>
-          logWarning(
-            log"The state for version ${MDC(LogKeys.FILE_VERSION, startVersion)} doesn't " +
-              log"exist in loadedMaps. Reading snapshot file and delta files if needed..." +
-              log"Note that this is normal for the first batch of starting query.")
-          loadedMapCacheMissCount.increment()
-          readSnapshotFile(startVersion)
+        case Some(value) => Option(value)
+        case None => readSnapshotFile(startVersion)
       }
       if (startVersionMap.isEmpty) {
         throw StateStoreErrors.stateStoreSnapshotFileNotFound(
           snapshotFile(startVersion).toString, toString())
       }
-      synchronized { putStateIntoStateCacheMap(startVersion, startVersionMap.get) }

       // Load all the deltas from the version after the start version up to the end version.
       val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
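The slimmed-down constructMapFromSnapshot above also stops touching the loadedMaps cache (no hit/miss counters, no putStateIntoStateCacheMap), which is the point of the commit: a state data source read replays a snapshot plus deltas without mutating the cache the running streaming query depends on. Below is a minimal, self-contained model of that replay order; the in-memory maps stand in for snapshot and delta files, deletions are omitted for brevity, and nothing in it is Spark API:

object SnapshotReplayModel {
  type StateMap = Map[String, Int]

  // Apply delta files in version order on top of the snapshot; later writes win.
  def replay(snapshot: StateMap, deltas: Seq[(Long, StateMap)]): StateMap =
    deltas.sortBy(_._1).foldLeft(snapshot) { case (state, (_, delta)) =>
      state ++ delta
    }

  def main(args: Array[String]): Unit = {
    val snapshotAtV5 = Map("a" -> 1, "b" -> 2)                  // snapshot file at startVersion
    val deltas = Seq(6L -> Map("a" -> 10), 7L -> Map("c" -> 3)) // deltas up to endVersion
    println(replay(snapshotAtV5, deltas))                       // Map(a -> 10, b -> 2, c -> 3)
  }
}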
File: RocksDB.scala
Expand Up @@ -233,18 +233,19 @@ class RocksDB(
*
* @param startVersion version of the snapshot to start with
* @param endVersion end version
* @param readOnly whether the RocksDB instance is read-only
* @return A RocksDB instance loaded with the state endVersion replayed from startVersion
* @return A RocksDB instance loaded with the state endVersion replayed from startVersion.
* Note that the instance will be read-only since this method is only used in State Data
* Source.
*/
def load(startVersion: Long, endVersion: Long, readOnly: Boolean): RocksDB = {
def loadFromSnapshot(startVersion: Long, endVersion: Long): RocksDB = {
assert(startVersion >= 0 && endVersion >= startVersion)
acquire(LoadStore)
recordedMetrics = None
logInfo(
log"Loading ${MDC(LogKeys.VERSION_NUM, endVersion)} from " +
log"${MDC(LogKeys.VERSION_NUM, startVersion)}")
try {
loadFromCheckpoint(startVersion, endVersion)
replayFromCheckpoint(startVersion, endVersion)

logInfo(
log"Loaded ${MDC(LogKeys.VERSION_NUM, endVersion)} from " +
@@ -254,11 +255,6 @@
       loadedVersion = -1 // invalidate loaded data
       throw t
     }
-    if (enableChangelogCheckpointing && !readOnly) {
-      // Make sure we don't leak resource.
-      changelogWriter.foreach(_.abort())
-      changelogWriter = Some(fileManager.getChangeLogWriter(endVersion + 1, useColumnFamilies))
-    }
     this
   }

@@ -270,7 +266,7 @@
    * @param startVersion start checkpoint version
    * @param endVersion end version
    */
-  def loadFromCheckpoint(startVersion: Long, endVersion: Long): Any = {
+  private def replayFromCheckpoint(startVersion: Long, endVersion: Long): Any = {
     if (loadedVersion != startVersion) {
       closeDB()
       val metadata = fileManager.loadCheckpointFromDfs(startVersion, workingDir)
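Dropping the changelog-writer block from the old load() is safe here because loadFromSnapshot only ever backs the read-only state data source: no commit follows, so there is no next version whose changes need a changelog writer. A toy sketch of that distinction, with illustrative names that are not the real RocksDB fields:

// Writable loads prepare to capture the next commit; read-only replays do not.
final class ToyVersionedStore(enableChangelogCheckpointing: Boolean) {
  private var changelogWriter: Option[StringBuilder] = None

  def load(version: Long): ToyVersionedStore = {
    if (enableChangelogCheckpointing) {
      changelogWriter = Some(new StringBuilder) // will record changes for version + 1
    }
    this
  }

  def loadFromSnapshot(startVersion: Long, endVersion: Long): ToyVersionedStore = {
    changelogWriter = None // read-only: nothing will ever be committed
    this
  }
}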
File: RocksDBStateStoreProvider.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils

 private[sql] class RocksDBStateStoreProvider
-  extends StateStoreProvider with Logging with Closeable {
+  extends StateStoreProvider with Logging with Closeable with FineGrainedStateSource {
   import RocksDBStateStoreProvider._

   class RocksDBStateStore(lastVersion: Long) extends StateStore {
@@ -309,15 +309,15 @@ private[sql] class RocksDBStateStoreProvider
     }
   }

-  override def getStore(startVersion: Long, endVersion: Long): StateStore = {
+  override def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
     try {
       if (startVersion < 1) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
       }
       if (endVersion < startVersion) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }
-      rocksDB.load(startVersion, endVersion, readOnly = false)
+      rocksDB.loadFromSnapshot(startVersion, endVersion)
       new RocksDBStateStore(endVersion)
     }
     catch {
@@ -338,15 +338,15 @@
     }
   }

-  override def getReadStore(startVersion: Long, endVersion: Long): StateStore = {
+  override def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore = {
     try {
       if (startVersion < 1) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
       }
       if (endVersion < startVersion) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }
-      rocksDB.load(startVersion, endVersion, readOnly = true)
+      rocksDB.loadFromSnapshot(startVersion, endVersion)
       new RocksDBStateStore(endVersion)
     }
     catch {
File: StateStore.scala
@@ -359,21 +359,6 @@ trait StateStoreProvider {
   /** Return an instance of [[StateStore]] representing state data of the given version */
   def getStore(version: Long): StateStore

-  /**
-   * This is an optional method, used by snapshotStartBatchId option when reading state generated
-   * by join operation as data source.
-   * Return an instance of [[StateStore]] representing state data of the given version.
-   * The State Store will be constructed from the batch at startVersion, and applying delta files
-   * up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
-   * be thrown.
-   *
-   * @param startVersion checkpoint version of the snapshot to start with
-   * @param endVersion checkpoint version to end with
-   */
-  def getStore(startVersion: Long, endVersion: Long): StateStore =
-    throw new SparkUnsupportedOperationException("getStore with startVersion and endVersion " +
-      s"is not supported by ${this.getClass.toString}")
-
   /**
    * Return an instance of [[ReadStateStore]] representing state data of the given version.
    * By default it will return the same instance as getStore(version) but wrapped to prevent
@@ -383,21 +368,6 @@
   def getReadStore(version: Long): ReadStateStore =
     new WrappedReadStateStore(getStore(version))

-  /**
-   * This is an optional method, used by snapshotStartBatchId option when reading state generated
-   * by all stateful operations except join as data source.
-   * Return an instance of [[ReadStateStore]] representing state data of the given version.
-   * The State Store will be constructed from the batch at startVersion, and applying delta files
-   * up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
-   * be thrown.
-   *
-   * @param startVersion checkpoint version of the snapshot to start with
-   * @param endVersion checkpoint version to end with
-   */
-  def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore =
-    throw new SparkUnsupportedOperationException("getReadStore with startVersion and endVersion " +
-      s"is not supported by ${this.getClass.toString}")
-
   /** Optional method for providers to allow for background maintenance (e.g. compactions) */
   def doMaintenance(): Unit = { }

@@ -469,6 +439,39 @@ object StateStoreProvider {
   }
 }

+/**
+ * An optional trait for [[StateStoreProvider]]s that can read fine-grained state data by
+ * replaying it from a specific snapshot version. It is used by the snapshotStartBatchId
+ * option of the state data source.
+ */
+trait FineGrainedStateSource {
+  /**
+   * Used by the snapshotStartBatchId option when reading, as a data source, state generated
+   * by the join operation.
+   * Returns an instance of [[StateStore]] representing state data of the given version.
+   * The state store is constructed from the snapshot at startVersion, with delta files applied
+   * up to endVersion. If there is no snapshot file for batch startVersion, an exception is
+   * thrown.
+   *
+   * @param startVersion checkpoint version of the snapshot to start with
+   * @param endVersion checkpoint version to end with
+   */
+  def replayStoreFromSnapshot(startVersion: Long, endVersion: Long): StateStore
+
+  /**
+   * Used by the snapshotStartBatchId option when reading, as a data source, state generated
+   * by all stateful operations except join.
+   * Returns an instance of [[ReadStateStore]] representing state data of the given version.
+   * The state store is constructed from the snapshot at startVersion, with delta files applied
+   * up to endVersion. If there is no snapshot file for batch startVersion, an exception is
+   * thrown.
+   *
+   * @param startVersion checkpoint version of the snapshot to start with
+   * @param endVersion checkpoint version to end with
+   */
+  def replayReadStoreFromSnapshot(startVersion: Long, endVersion: Long): ReadStateStore
+}

 /**
  * Unique identifier for a provider, used to identify when providers can be reused.
  * Note that `queryRunId` is used to uniquely identify a provider, so that the same provider
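Both call sites in this commit (StatePartitionReader and SymmetricHashJoinStateManager, below) reach the new trait with provider.asInstanceOf[FineGrainedStateSource], which throws a bare ClassCastException if a third-party provider never opted in. A hedged sketch of a match-based guard that surfaces a clearer error; this is an illustrative alternative, not code from the change:

import org.apache.spark.sql.execution.streaming.state.{FineGrainedStateSource, ReadStateStore, StateStoreProvider}

object SnapshotReplayGuard {
  // Fail with an explicit message when a provider cannot replay from snapshots.
  def replayReadStore(
      provider: StateStoreProvider,
      startVersion: Long,
      endVersion: Long): ReadStateStore = provider match {
    case fineGrained: FineGrainedStateSource =>
      fineGrained.replayReadStoreFromSnapshot(startVersion, endVersion)
    case other =>
      throw new UnsupportedOperationException(
        s"${other.getClass.getName} does not implement FineGrainedStateSource")
  }
}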
File: SymmetricHashJoinStateManager.scala
@@ -493,7 +493,8 @@ class SymmetricHashJoinStateManager(
       useColumnFamilies = false, storeConf, hadoopConf,
       useMultipleValuesPerKey = false)
     if (snapshotStartVersion.isDefined) {
-      stateStoreProvider.getStore(snapshotStartVersion.get, stateInfo.get.storeVersion)
+      stateStoreProvider.asInstanceOf[FineGrainedStateSource]
+        .replayStoreFromSnapshot(snapshotStartVersion.get, stateInfo.get.storeVersion)
     } else {
       stateStoreProvider.getStore(stateInfo.get.storeVersion)
     }
File: StateDataSourceReadSuite.scala
@@ -985,7 +985,7 @@
       }

       val exc = intercept[SparkException] {
-        provider.getReadStore(1, 2)
+        provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(1, 2)
       }
       checkError(exc, "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED")
     })
@@ -1001,7 +1001,7 @@
         provider.doMaintenance()
       }

-      val result = provider.getReadStore(2, 3)
+      val result = provider.asInstanceOf[FineGrainedStateSource].replayReadStoreFromSnapshot(2, 3)

      assert(get(result, "a", 1).get == 1)
      assert(get(result, "a", 2).get == 2)
