
Commit

feedback and nits
ericm-db committed Apr 15, 2024
1 parent aa7388d commit ab07117
Showing 7 changed files with 197 additions and 67 deletions.
@@ -88,9 +88,9 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
* @return - instance of ListState of type T that can be used to store state persistently
*/
def getListState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T]
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T]
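
(For context, not part of this diff: a minimal usage sketch of the TTL-enabled list state API above. The handle parameter, the "userEvents" state name, and the one-hour TTL are illustrative assumptions; the getListState call itself mirrors the one exercised in ListStateSuite further down, and the imports assume the public org.apache.spark.sql.streaming package.)

import java.time.Duration

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{ListState, StatefulProcessorHandle, TTLConfig}

// Obtain a list state whose entries become eligible for TTL eviction one hour
// after they are written. `handle` is assumed to be the StatefulProcessorHandle
// available to a StatefulProcessor; "userEvents" is a hypothetical state name.
def initUserEvents(handle: StatefulProcessorHandle): ListState[String] = {
  val ttlConfig = TTLConfig(Duration.ofHours(1))
  handle.getListState[String]("userEvents", Encoders.STRING, ttlConfig)
}
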

/**
* Creates new or returns existing map state associated with stateName.
@@ -36,12 +36,12 @@ import org.apache.spark.util.NextIterator
* @tparam S - data type of object that will be stored
*/
class ListStateImplWithTTL[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
ttlConfig: TTLConfig,
batchTimestampMs: Long)
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
ttlConfig: TTLConfig,
batchTimestampMs: Long)
extends SingleKeyTTLStateImpl(stateName, store, batchTimestampMs) with ListState[S] {

private lazy val keySerializer = keyExprEnc.createSerializer()
@@ -193,9 +193,8 @@ class ListStateImplWithTTL[S](
private[sql] def getWithoutEnforcingTTL(): Iterator[S] = {
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
unsafeRowValuesIterator.map{
valueUnsafeRow =>
stateTypesEncoder.decodeValue(valueUnsafeRow)
unsafeRowValuesIterator.map { valueUnsafeRow =>
stateTypesEncoder.decodeValue(valueUnsafeRow)
}
}

@@ -205,12 +204,12 @@
private[sql] def getTTLValues(): Iterator[(S, Long)] = {
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
unsafeRowValuesIterator.map{
valueUnsafeRow =>
(stateTypesEncoder.decodeValue(valueUnsafeRow),
stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
unsafeRowValuesIterator.map { valueUnsafeRow =>
(stateTypesEncoder.decodeValue(valueUnsafeRow),
stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
}
}
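
(A hypothetical helper, not part of this change, to illustrate the (value, expiration timestamp) pairs that getTTLValues() above produces: each element pairs a decoded value with its TTL expiration in epoch milliseconds, which can be compared against batchTimestampMs. The helper name and the strict `>` comparison are assumptions, not the documented behavior of this class.)

// Keep only entries whose TTL expiration lies strictly after the given batch
// timestamp; expired entries are dropped from the iterator.
def unexpiredValues[S](
    ttlValues: Iterator[(S, Long)],
    batchTimestampMs: Long): Iterator[S] = {
  ttlValues.collect {
    case (value, expirationMs) if expirationMs > batchTimestampMs => value
  }
}
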

/**
* Get all ttl values stored in ttl state for current implicit
* grouping key.
@@ -261,9 +261,9 @@ class StatefulProcessorHandleImpl(
* @return - instance of ListState of type T that can be used to store state persistently
*/
override def getListState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T] = {
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T] = {

verifyStateVarOperations("get_list_state")
validateTTLConfig(ttlConfig, stateName)
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
import java.time.Duration
import java.util.UUID

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, ListStateImplWithTTL, StatefulProcessorHandleImpl}
@@ -217,4 +217,31 @@ class ListStateSuite extends StateVariableSuiteBase {
assert(nextBatchTestState.get().isEmpty)
}
}

test("test negative or zero TTL duration throws error") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val batchTimestampMs = 10
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))

Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
val ttlConfig = TTLConfig(ttlDuration)
val ex = intercept[SparkUnsupportedOperationException] {
handle.getListState[String]("testState", Encoders.STRING, ttlConfig)
}

checkError(
ex,
errorClass = "STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE",
parameters = Map(
"operationType" -> "update",
"stateName" -> "testState"
),
matchPVals = true
)
}
}
}
}
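
(The new ListStateSuite test above checks that list state creation rejects invalid TTL durations. As a hedged sketch only: the validity condition it exercises amounts to the predicate below; the real check lives in the handle's validateTTLConfig and surfaces as a SparkUnsupportedOperationException with error class STATEFUL_PROCESSOR_TTL_DURATION_MUST_BE_POSITIVE. The helper name is illustrative.)

import java.time.Duration

// A TTL duration is acceptable only when it is non-null and strictly positive;
// null, Duration.ZERO, and negative durations (the cases in the test above)
// must all be rejected.
def isPositiveTtl(ttlDuration: Duration): Boolean =
  ttlDuration != null && !ttlDuration.isZero && !ttlDuration.isNegative
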
@@ -219,7 +219,7 @@ class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
}
}

test(s"ttl States are populated for valueState and timeMode=ProcessingTime") {
test("ttl States are populated for valueState and timeMode=ProcessingTime") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
@@ -237,7 +237,7 @@
}
}

test(s"ttl States are populated for listState and timeMode=ProcessingTime") {
test("ttl States are populated for listState and timeMode=ProcessingTime") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
@@ -255,7 +255,7 @@
}
}

test(s"ttl States are not populated for timeMode=None") {
test("ttl States are not populated for timeMode=None") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
val store = provider.getStore(0)
val handle = new StatefulProcessorHandleImpl(store,
