Skip to content

Commit

Permalink
support reading join states
Browse files Browse the repository at this point in the history
  • Loading branch information
eason-yuchen-liu committed Jun 12, 2024
1 parent 61dea35 commit 5229152
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ class StreamStreamJoinStatePartitionReader(
partitionId = partition.partition,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false
useStateStoreCoordinator = false,
snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
new HDFSBackedStateStore(version, newMap)
}

/**
* Get the state store of endVersion for reading by applying delta files on the snapshot of
* startVersion. If startVersion does not exist, an error will be thrown.
*
* @param startVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
override def getStore(startVersion: Long, endVersion: Long): StateStore = {
val newMap = getLoadedMapForStore(startVersion, endVersion)
logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
new HDFSBackedStateStore(endVersion, newMap)
}

/** Get the state store for reading to specific `version` of the store. */
override def getReadStore(version: Long): ReadStateStore = {
val newMap = getLoadedMapForStore(version)
Expand All @@ -275,14 +290,13 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
*
* @param startVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @return
*/
override def getReadStore(startVersion: Long, endVersion: Long): ReadStateStore = {
val newMap = getLoadedMapForStore(startVersion, endVersion)
logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, startVersion)} to " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for readonly")
new HDFSBackedReadStateStore(startVersion, newMap)
new HDFSBackedReadStateStore(endVersion, newMap)
}

private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,28 @@ private[sql] class RocksDBStateStoreProvider
}
}

override def getStore(startVersion: Long, endVersion: Long): StateStore = {
try {
if (startVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
}
if (endVersion < startVersion) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}
rocksDB.load(startVersion, endVersion, readOnly = false)
new RocksDBStateStore(endVersion)
}
catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

override def getReadStore(version: Long): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(version, true)
rocksDB.load(version, readOnly = true)
new RocksDBStateStore(version)
}
catch {
Expand All @@ -330,7 +346,7 @@ private[sql] class RocksDBStateStoreProvider
if (endVersion < startVersion) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}
rocksDB.load(startVersion, endVersion, true)
rocksDB.load(startVersion, endVersion, readOnly = true)
new RocksDBStateStore(endVersion)
}
catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,21 @@ trait StateStoreProvider {
/** Return an instance of [[StateStore]] representing state data of the given version */
def getStore(version: Long): StateStore

/**
* This is an optional method, used by snapshotStartBatchId option when reading state generated
* by join operation as data source.
* Return an instance of [[StateStore]] representing state data of the given version.
* The State Store will be constructed from the batch at startVersion, and applying delta files
* up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
* be thrown.
*
* @param startVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
def getStore(startVersion: Long, endVersion: Long): StateStore =
throw new SparkUnsupportedOperationException("getStore with startVersion and endVersion " +
s"is not supported by ${this.getClass.toString}")

/**
* Return an instance of [[ReadStateStore]] representing state data of the given version.
* By default it will return the same instance as getStore(version) but wrapped to prevent
Expand All @@ -379,8 +394,8 @@ trait StateStoreProvider {
new WrappedReadStateStore(getStore(version))

/**
* This is an optional method, used by snapshotStartBatchId option when reading state as data
* source.
* This is an optional method, used by snapshotStartBatchId option when reading state generated
* by all stateful operations except join as data source.
* Return an instance of [[ReadStateStore]] representing state data of the given version.
* The State Store will be constructed from the batch at startVersion, and applying delta files
* up to the endVersion. If there is no snapshot file of batch startVersion, an exception will
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class SymmetricHashJoinStateManager(
partitionId: Int,
stateFormatVersion: Int,
skippedNullValueCount: Option[SQLMetric] = None,
useStateStoreCoordinator: Boolean = true) extends Logging {
useStateStoreCoordinator: Boolean = true,
snapshotStartVersion: Option[Long] = None) extends Logging {
import SymmetricHashJoinStateManager._

/*
Expand Down Expand Up @@ -480,6 +481,8 @@ class SymmetricHashJoinStateManager(
val storeProviderId = StateStoreProviderId(
stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
val store = if (useStateStoreCoordinator) {
assert(snapshotStartVersion.isEmpty, "Should not use state store coordinator " +
"when reading state as data source.")
StateStore.get(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
stateInfo.get.storeVersion, useColumnFamilies = false, storeConf, hadoopConf)
Expand All @@ -489,7 +492,12 @@ class SymmetricHashJoinStateManager(
storeProviderId, keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
useColumnFamilies = false, storeConf, hadoopConf,
useMultipleValuesPerKey = false)
stateStoreProvider.getStore(stateInfo.get.storeVersion)
if (snapshotStartVersion.isDefined) {
stateStoreProvider.getStore(snapshotStartVersion.get, stateInfo.get.storeVersion)
}
else {
stateStoreProvider.getStore(stateInfo.get.storeVersion)
}
}
logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}")
store
Expand Down

0 comments on commit 5229152

Please sign in to comment.