address reviews by Wei partially
eason-yuchen-liu committed Jun 13, 2024 (commit 4825215, 1 parent: 5229152)
Showing 4 changed files with 31 additions and 35 deletions.

@@ -93,12 +93,10 @@ class StatePartitionReader(
   }
 
   private lazy val store: ReadStateStore = {
-    if (partition.sourceOptions.snapshotStartBatchId.isEmpty) {
-      provider.getReadStore(partition.sourceOptions.batchId + 1)
-    }
-    else {
-      provider.getReadStore(
-        partition.sourceOptions.snapshotStartBatchId.get + 1,
+    partition.sourceOptions.snapshotStartBatchId match {
+      case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
+      case Some(snapshotStartBatchId) => provider.getReadStore(
+        snapshotStartBatchId + 1,
         partition.sourceOptions.batchId + 1)
     }
   }
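
For readers following the patch: the change above replaces an `isEmpty`/`.get` pair with a single pattern match on the `Option`, and routes snapshot reads through the two-argument `getReadStore` overload. A minimal, self-contained sketch of the version arithmetic, using stand-in types rather than the actual Spark provider API:

```scala
// Stand-ins only; the real code calls provider.getReadStore(...) as shown in the hunk above.
object SnapshotVersionSketch extends App {
  final case class ReadRange(startVersion: Long, endVersion: Long)

  // Batch N corresponds to state store version N + 1.
  def readRange(snapshotStartBatchId: Option[Long], batchId: Long): ReadRange =
    snapshotStartBatchId match {
      case None => ReadRange(batchId + 1, batchId + 1)       // read only the requested version
      case Some(start) => ReadRange(start + 1, batchId + 1)  // replay from the chosen snapshot
    }

  // Reading batch 9 on top of the snapshot taken for batch 4.
  assert(readRange(Some(4L), 9L) == ReadRange(5L, 10L))
  assert(readRange(None, 9L) == ReadRange(10L, 10L))
  println("SnapshotVersionSketch ok")
}
```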

@@ -82,18 +82,16 @@ class StateScan(
     assert((tail - head + 1) == partitionNums.length,
       s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")
 
-    if (sourceOptions.snapshotPartitionId.isEmpty) {
-      partitionNums.map {
+    sourceOptions.snapshotPartitionId match {
+      case None => partitionNums.map {
         pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
       }.toArray
-    }
-    else {
-      val snapshotPartitionId = sourceOptions.snapshotPartitionId.get
-      if (partitionNums.contains(snapshotPartitionId)) {
-        Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
-      } else {
-        throw QueryExecutionErrors.snapshotPartitionNotFoundError(snapshotPartitionId)
-      }
+      case Some(snapshotPartitionId) =>
+        if (partitionNums.contains(snapshotPartitionId)) {
+          Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
+        } else {
+          throw QueryExecutionErrors.snapshotPartitionNotFoundError(snapshotPartitionId)
+        }
     }
   }
 }
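
The partition-selection change has the same shape: match on the `Option`, and validate an explicitly requested partition before building a single-element array. A standalone sketch of that selection logic, with plain `Int` ids and a generic exception standing in for `QueryExecutionErrors.snapshotPartitionNotFoundError`:

```scala
// Illustrative stand-in for the partition selection in StateScan above.
object PartitionSelectionSketch extends App {
  def selectPartitions(partitionNums: Seq[Int], snapshotPartitionId: Option[Int]): Array[Int] =
    snapshotPartitionId match {
      // No explicit partition requested: scan every partition found in the checkpoint.
      case None => partitionNums.toArray
      // An explicitly requested partition must exist, otherwise fail fast.
      case Some(id) =>
        if (partitionNums.contains(id)) Array(id)
        else throw new IllegalArgumentException(s"snapshot partition $id not found")
    }

  assert(selectPartitions(Seq(0, 1, 2), None).sameElements(Array(0, 1, 2)))
  assert(selectPartitions(Seq(0, 1, 2), Some(2)).sameElements(Array(2)))
  println("PartitionSelectionSketch ok")
}
```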

@@ -326,7 +326,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     }
 
     val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
-    if (!(endVersion == 0)) {
+    if (endVersion != 0) {
      newMap.putAll(loadMap(startVersion, endVersion))
     }
     newMap
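
`endVersion != 0` is the clearer spelling of the old guard: version 0 denotes an empty store, so there is nothing to load into the fresh map. As a rough, hypothetical illustration of the snapshot-plus-deltas reconstruction that `loadMap(startVersion, endVersion)` stands for here (the real provider reads snapshot and delta files from the checkpoint location; these helpers are invented for the sketch):

```scala
// Hypothetical, simplified model: start from the snapshot at `startVersion`
// and fold in a delta for every version up to `endVersion`.
object SnapshotReplaySketch extends App {
  type StateMap = Map[String, Long]

  def readSnapshot(version: Long): StateMap =
    if (version == 0) Map.empty else Map("base" -> version)

  def readDelta(version: Long): StateMap = Map(s"delta-$version" -> version)

  def loadMap(startVersion: Long, endVersion: Long): StateMap =
    (startVersion + 1 to endVersion)
      .foldLeft(readSnapshot(startVersion))((m, v) => m ++ readDelta(v))

  // endVersion == 0 means an empty store, so callers skip loadMap entirely.
  assert(loadMap(2, 4).keySet == Set("base", "delta-3", "delta-4"))
  println("SnapshotReplaySketch ok")
}
```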

@@ -374,19 +374,20 @@ class StateDataSourceSQLConfigSuite extends StateDataSourceTestBase {
   }
 }
 
-class HDFSBackedStateDataSourceReadSuite
-  extends StateDataSourceReadSuite[HDFSBackedStateStoreProvider] {
+class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
+    new HDFSBackedStateStoreProvider
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
       classOf[HDFSBackedStateStoreProvider].getName)
+    // make sure we have a snapshot for every two delta files
+    // HDFS maintenance task will not count the latest delta file, which has the same version
+    // as the snapshot version
     spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 1)
   }
 
-  override protected def newStateStoreProvider(): HDFSBackedStateStoreProvider =
-    new HDFSBackedStateStoreProvider
-
   test("ERROR: snapshot of version not found") {
     testSnapshotNotFound()
   }
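
The comment added above captures why the suite sets `STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT` to 1: the HDFS-backed maintenance task does not count the latest delta file, so this value yields a snapshot roughly every two delta files, which the snapshot-reading tests rely on. Outside the suite, an equivalent session setup would look roughly like this (a sketch, not part of the patch; it assumes the usual SQLConf key strings):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: have the HDFS-backed provider write a snapshot about every two
// delta files, so snapshotStartBatchId reads have snapshots to start from.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("state-source-snapshot-demo")
  .config("spark.sql.streaming.stateStore.providerClass",
    "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")
  .config("spark.sql.streaming.stateStore.minDeltasForSnapshot", "1")
  .getOrCreate()
```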

@@ -400,35 +401,35 @@ class HDFSBackedStateDataSourceReadSuite
   }
 }
 
-class RocksDBStateDataSourceReadSuite
-  extends StateDataSourceReadSuite[RocksDBStateStoreProvider] {
+class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
+    new RocksDBStateStoreProvider
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
       classOf[RocksDBStateStoreProvider].getName)
     spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
       "false")
   }
-
-  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
-    new RocksDBStateStoreProvider
 }
 
-class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite
-  extends StateDataSourceReadSuite[RocksDBStateStoreProvider] {
+class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends
+  StateDataSourceReadSuite {
+  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
+    new RocksDBStateStoreProvider
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
-      classOf[RocksDBStateStoreProvider].getName)
+      newStateStoreProvider().getClass.getName)
     spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
       "true")
     // make sure we have a snapshot for every other checkpoint
     // RocksDB maintenance task will count the latest checkpoint, so we need to set it to 2
     spark.conf.set(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key, 2)
   }
 
-  override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
-    new RocksDBStateStoreProvider
-
   test("ERROR: snapshot of version not found") {
     testSnapshotNotFound()
   }

@@ -442,16 +443,15 @@ class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite
   }
 }
 
-abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
-  extends StateDataSourceTestBase with Assertions {
+abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {
 
   import testImplicits._
   import StateStoreTestsHelper._
 
   protected val keySchema: StructType = StateStoreTestsHelper.keySchema
   protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema
 
-  protected def newStateStoreProvider(): storeProvider
+  protected def newStateStoreProvider(): StateStoreProvider
 
   /**
    * Calls the overridable [[newStateStoreProvider]] to create the state store provider instance.

@@ -460,7 +460,7 @@ abstract class StateDataSourceReadSuite[storeProvider <: StateStoreProvider]
    * @param checkpointDir path to store state information
    * @return instance of class extending [[StateStoreProvider]]
    */
-  private def getNewStateStoreProvider(checkpointDir: String): storeProvider = {
+  private def getNewStateStoreProvider(checkpointDir: String): StateStoreProvider = {
     val provider = newStateStoreProvider()
     provider.init(
       StateStoreId(checkpointDir, 0, 0),
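
Dropping the `[storeProvider <: StateStoreProvider]` type parameter works because Scala lets an override narrow the result type covariantly: each concrete suite can still declare the exact provider it builds while the base suite only depends on the `StateStoreProvider` interface. A tiny self-contained sketch of that rule, with toy traits rather than the Spark classes:

```scala
// Toy model of the refactor: no type parameter on the base class;
// subclasses narrow the return type of the factory method instead.
object CovariantOverrideSketch extends App {
  trait StateStoreProvider { def name: String }
  final class HdfsProvider extends StateStoreProvider { val name = "hdfs" }

  abstract class ReadSuite {
    protected def newStateStoreProvider(): StateStoreProvider
    def describe(): String = newStateStoreProvider().name
  }

  class HdfsReadSuite extends ReadSuite {
    // Narrower return type than the base declaration: still a legal override.
    override protected def newStateStoreProvider(): HdfsProvider = new HdfsProvider
  }

  assert(new HdfsReadSuite().describe() == "hdfs")
  println("CovariantOverrideSketch ok")
}
```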
