make sure test is stable
eason-yuchen-liu committed Jun 10, 2024
1 parent eddb3c7 commit 1a3d20a
Showing 4 changed files with 130 additions and 97 deletions.
@@ -296,15 +296,15 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
private def getLoadedMapForStore(startVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = synchronized {
try {
if (startVersion < 0) {
if (startVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
}
if (endVersion < startVersion || endVersion < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}

val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
if (!(startVersion == 0 && endVersion == 0)) {
if (!(endVersion == 0)) {
newMap.putAll(loadMap(startVersion, endVersion))
}
newMap
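For context, the tightened checks above make the requested read range 1-based: a startVersion of 0 is now rejected, and endVersion may not precede startVersion. A minimal standalone sketch of that contract, with a hypothetical method name and plain exceptions in place of QueryExecutionErrors:

// Hypothetical sketch of the 1-based version-range contract enforced above;
// not the provider's actual error-handling API.
def validateVersionRange(startVersion: Long, endVersion: Long): Unit = {
  if (startVersion < 1) {
    throw new IllegalArgumentException(s"unexpected state store version: $startVersion")
  }
  if (endVersion < startVersion) {
    throw new IllegalArgumentException(s"unexpected state store version: $endVersion")
  }
}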
@@ -180,10 +180,38 @@ class RocksDB(
logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)}")
try {
if (loadedVersion != version) {
closeDB()
val latestSnapshotVersion = fileManager.getLatestSnapshotVersion(version)
loadFromCheckpoint(latestSnapshotVersion, version)
val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion, workingDir)
loadedVersion = latestSnapshotVersion

// reset last snapshot version
if (lastSnapshotVersion > latestSnapshotVersion) {
// discard any newer snapshots
lastSnapshotVersion = 0L
latestSnapshot = None
}
openDB()

numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) {
// we don't track the total number of rows - discard the number being tracked
-1L
} else if (metadata.numKeys < 0) {
// we track the total number of rows, but the snapshot doesn't have the tracked number,
// so we need to count the keys now
countKeys()
} else {
metadata.numKeys
}
if (loadedVersion != version) replayChangelog(version)
// After changelog replay the numKeysOnWritingVersion will be updated to
// the correct number of keys in the loaded version.
numKeysOnLoadedVersion = numKeysOnWritingVersion
fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
}
if (conf.resetStatsOnLoad) {
nativeStats.reset
}

logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)}")
} catch {
case t: Throwable =>
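In plain terms, the reworked load path above restores the newest snapshot at or below the requested version, resets the cached snapshot bookkeeping if it is newer than what was just loaded, reuses or recomputes the key count, and then replays per-version changelog files until the target version is reached. A minimal sketch of that strategy, with hypothetical injected helpers rather than the actual RocksDB/RocksDBFileManager API:

// Hypothetical sketch: restore the latest snapshot <= target, then replay
// changelogs one version at a time; the helpers are assumed, not Spark's API.
def loadVersion(
    target: Long,
    latestSnapshotVersionAtOrBelow: Long => Long, // assumed helper, injected
    restoreSnapshot: Long => Unit,                // assumed helper, injected
    applyChangelog: Long => Unit                  // assumed helper, injected
  ): Unit = {
  val snapshotVersion = latestSnapshotVersionAtOrBelow(target)
  restoreSnapshot(snapshotVersion)
  var v = snapshotVersion
  while (v < target) {
    v += 1
    applyChangelog(v)
  }
}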
@@ -324,7 +324,7 @@ private[sql] class RocksDBStateStoreProvider

override def getReadStore(startVersion: Long, endVersion: Long): StateStore = {
try {
if (startVersion < 0) {
if (startVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(startVersion)
}
if (endVersion < startVersion) {
@@ -20,7 +20,6 @@ import java.io.{File, FileWriter}

import org.apache.hadoop.conf.Configuration
import org.scalatest.Assertions

import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
@@ -380,10 +379,28 @@ class HDFSBackedStateDataSourceReadSuite
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
classOf[HDFSBackedStateStoreProvider].getName)
// make sure we have a snapshot for every two delta files
spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 1)
}

override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
new HDFSBackedStateStoreProvider

test("ERROR: snapshot partition not found") {
testPartitionNotFound()
}

test("provider.getReadStore(startVersion, endVersion)") {
testGetReadStoreWithStart()
}

test("option snapshotPartitionId") {
testSnapshotPartitionId()
}

test("snapshotStartBatchId and snapshotPartitionId end to end") {
testSnapshotEndToEnd()
}
}

class RocksDBStateDataSourceReadSuite
@@ -408,43 +425,66 @@ class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite
classOf[RocksDBStateStoreProvider].getName)
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
"true")
// make sure we have a snapshot for every other checkpoint
spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 2)
}

override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

test("ERROR: snapshot partition not found") {
testPartitionNotFound()
}

test("provider.getReadStore(startVersion, endVersion)") {
testGetReadStoreWithStart()
}

test("option snapshotPartitionId") {
testSnapshotPartitionId()
}
}

abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
extends StateDataSourceTestBase with Assertions {

import testImplicits._
import StateStoreTestsHelper._

protected val keySchema: StructType = StateStoreTestsHelper.keySchema
protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema

protected def newStateStoreProvider(): storeProvider

protected def getNewStateStoreProvider(checkpointDir: String): storeProvider = {
val minDeltasForSnapshot = 1 // overrides the default of 10
val numOfVersToRetainInMemory = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
val sqlConf = new SQLConf()
sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
sqlConf.setConf(SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY, numOfVersToRetainInMemory)
sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
sqlConf.setConf(SQLConf.STATE_STORE_COMPRESSION_CODEC, SQLConf.get.stateStoreCompressionCodec)
def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
}

def get(store: ReadStateStore, key1: String, key2: Int): Option[Int] = {
Option(store.get(dataToKeyRow(key1, key2))).map(valueRowToData)
}

/**
* Calls the overridable [[newStateStoreProvider]] to create the state store provider instance,
* and initializes it with default settings except for STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.
*
* @param checkpointDir path to store state information
* @param minDeltasForSnapshot a snapshot is created for every minDeltasForSnapshot + 1 delta files
* @return the initialized state store provider
*/
private def getNewStateStoreProvider(checkpointDir: String): storeProvider = {
val provider = newStateStoreProvider()
provider.init(
StateStoreId(checkpointDir, 0, 0),
keySchema,
valueSchema,
NoPrefixKeyStateEncoderSpec(keySchema),
useColumnFamilies = false,
StateStoreConf(sqlConf),
StateStoreConf(spark.sessionState.conf),
new Configuration)
provider
}

test("simple aggregation, state ver 1") {
testStreamingAggregation(1)
}
@@ -912,97 +952,62 @@ abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
}
}

def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
}
protected def testPartitionNotFound(): Unit = {
withTempDir(tempDir => {
val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
for (i <- 1 to 4) {
val store = provider.getStore(i - 1)
put(store, "a", i, i)
store.commit()
provider.doMaintenance() // create a snapshot every other delta file
}

test("ERROR: snapshot partition not found") {
withTempDir(tempDir1 => {
val tempDir = new java.io.File("/tmp/state/test/")
val exc = intercept[SparkException] {
val provider = getNewStateStoreProvider(tempDir.getAbsolutePath + "/state/")
// val checker = new StateSchemaCompatibilityChecker(
// new StateStoreProviderId(provider.stateStoreId, UUID.randomUUID()), new Configuration())
// checker.createSchemaFile(keySchema, valueSchema)
for (i <- 1 to 4) {
val store = provider.getStore(i - 1)
put(store, "a", 0, i)
store.commit()
provider.doMaintenance() // do cleanup
}
// val stateStore = provider.getStore(0)

// put(stateStore, "a", 1, 1)
// put(stateStore, "b", 2, 2)
// println(stateStore.hasCommitted)
// println(stateStore.getClass.toString)

// stateStore.commit()
provider.close()

// println(stateStore.hasCommitted)
provider.getReadStore(1, 2)
}
checkError(exc, "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED")
})
}

val df = spark.read.format("statestore")
.option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
.option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)
.option(StateSourceOptions.BATCH_ID, 0)
.load(tempDir.getAbsolutePath)
protected def testGetReadStoreWithStart(): Unit = {
withTempDir(tempDir => {
val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
for (i <- 1 to 4) {
val store = provider.getStore(i - 1)
put(store, "a", i, i)
store.commit()
provider.doMaintenance()
}

println(df.rdd.getNumPartitions)
val result = provider.getReadStore(2, 3)

assert(get(result, "a", 1).get == 1)
assert(get(result, "a", 2).get == 2)
assert(get(result, "a", 3).get == 3)
assert(get(result, "a", 4).isEmpty)

val result = provider.getReadStore(0, 1)
provider.close()
})
}

protected def testSnapshotPartitionId(): Unit = {
withTempDir(tempDir => {
val inputData = MemoryStream[Int]
val df = inputData.toDF().limit(10)

}
assert(exc.getCause.getMessage.contains(
"CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS"))
})
testStream(df)(
StartStream(checkpointLocation = tempDir.getAbsolutePath),
AddData(inputData, 1, 2, 3, 4),
CheckLastBatch(1, 2, 3, 4)
)

val exc = intercept[SparkException] {
val checkpointPath = this.getClass.getResource(
"/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
spark.read.format("statestore")
val stateDf = spark.read.format("statestore")
.option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
.option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)
.load(checkpointPath).show()
}
assert(exc.getCause.getMessage.contains(
"CANNOT_LOAD_STATE_STORE.CANNOT_READ_SNAPSHOT_FILE_NOT_EXISTS"))
}

test("reconstruct state from specific snapshot and partition") {
val checkpointPath = this.getClass.getResource(
"/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
val stateFromBatch11 = spark.read.format("statestore")
.option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 11)
.option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
.load(checkpointPath)
val stateFromBatch23 = spark.read.format("statestore")
.option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 23)
.option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
.load(checkpointPath)
val stateFromLatestBatch = spark.read.format("statestore").load(checkpointPath)
val stateFromLatestBatchPartition1 = stateFromLatestBatch.filter(
stateFromLatestBatch("partition_id") === 1)

checkAnswer(stateFromBatch23, stateFromLatestBatchPartition1)
checkAnswer(stateFromBatch11, stateFromLatestBatchPartition1)
}

test("use snapshotStartBatchId together with batchId") {
val checkpointPath = this.getClass.getResource(
"/structured-streaming/checkpoint-version-4.0.0-state-source/").getPath
val stateFromBatch11 = spark.read.format("statestore")
.option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 11)
.option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
.option(StateSourceOptions.BATCH_ID, 20)
.load(checkpointPath)
val stateFromLatestBatch = spark.read.format("statestore")
.option(StateSourceOptions.BATCH_ID, 20).load(checkpointPath)
val stateFromLatestBatchPartition1 = stateFromLatestBatch.filter(
stateFromLatestBatch("partition_id") === 1)

checkAnswer(stateFromBatch11, stateFromLatestBatchPartition1)
.option(StateSourceOptions.BATCH_ID, 0)
.load(tempDir.getAbsolutePath)

assert(stateDf.rdd.getNumPartitions == 1)
})
}
}
