Merge branch 'apache:master' into skipSnapshotAtBatch
eason-yuchen-liu authored Jul 2, 2024
2 parents 8fa9ef5 + fea930a commit 9af25f1
Showing 4 changed files with 37 additions and 19 deletions.
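
The net effect of the four diffs below: the one-time read of the stateful operator's metadata, previously done inside each StatePartitionReader (once per partition, on the executors), is hoisted into StateDataSource.getTable on the driver, and the filtered StateMetadataTableEntry array is threaded down through StateTable, StateScanBuilder, StateScan and StatePartitionReaderFactory. A minimal usage sketch of the code path being touched follows; the option names are assumed to match the camelCase fields (batchId, operatorId, storeName) referenced by StateSourceOptions in this diff, and the checkpoint directory is a placeholder:

import org.apache.spark.sql.SparkSession

object ReadStateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("state-store-reader")
      .master("local[*]")
      .getOrCreate()

    // Placeholder: checkpoint directory of a previously run streaming query.
    val checkpointDir = "/tmp/my-streaming-query/checkpoint"

    val stateDf = spark.read
      .format("statestore")        // shortName() registered by StateDataSource
      .option("batchId", 5L)       // assumed option name: which committed batch to read
      .option("operatorId", 0L)    // assumed option name: which stateful operator
      .load(checkpointDir)

    // The table schema exposes nested "key" and "value" structs, which is how
    // StatePartitionReader below reads them back.
    stateDf.printSchema()
    stateDf.show(truncate = false)

    spark.stop()
  }
}

With this commit, a scan created that way carries metadata resolved once at table-creation time instead of re-reading it in every partition's reader.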
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -30,13 +30,15 @@ import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
 import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
@@ -46,6 +48,8 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()
 
+  private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf)
+
   override def shortName(): String = "statestore"
 
   override def getTable(
@@ -54,7 +58,17 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    new StateTable(session, schema, sourceOptions, stateConf)
+    // Read the operator metadata once to see if we can find the information for prefix scan
+    // encoder used in session window aggregation queries.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf)
+      .stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == sourceOptions.operatorId &&
+        entry.stateStoreName == sourceOptions.storeName
+    }
+
+    new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
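A side note on the new serializedHadoopConf field: Hadoop's Configuration is not java.io.Serializable, so Spark wraps it before handing it to anything that may be serialized. The sketch below illustrates the wrapping; it assumes a context where org.apache.spark.util.SerializableConfiguration is accessible (it is inside Spark's own modules; visibility to third-party code may depend on the Spark version), and it uses the public sparkContext.hadoopConfiguration rather than the session-level newHadoopConf() call used by the patch:

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.SerializableConfiguration

// Hadoop's Configuration is not java.io.Serializable, so it must be wrapped
// before being captured by task closures, broadcasts, or readers that may be
// shipped off the driver.
def wrapHadoopConf(session: SparkSession): SerializableConfiguration = {
  val hadoopConf: Configuration = session.sparkContext.hadoopConfiguration
  new SerializableConfiguration(hadoopConf)
}

// Consumers recover the original Configuration through .value, which is
// presumably how StateMetadataPartitionReader uses the argument passed above:
// val conf: Configuration = wrapHadoopConf(spark).value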
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.types.StructType
@@ -33,11 +33,12 @@ import org.apache.spark.util.SerializableConfiguration
 class StatePartitionReaderFactory(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
-    schema: StructType) extends PartitionReaderFactory {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     new StatePartitionReader(storeConf, hadoopConf,
-      partition.asInstanceOf[StateStoreInputPartition], schema)
+      partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
   }
 }
 
@@ -49,7 +50,9 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] with Logging {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
+  extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -58,13 +61,6 @@ class StatePartitionReader(
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == partition.sourceOptions.operatorId &&
-        entry.stateStoreName == partition.sourceOptions.storeName
-    }
     val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
       logWarning("Metadata for state store not found, possible cause is this checkpoint " +
         "is created by older version of spark. If the query has session window aggregation, " +
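The hunk above ends mid-statement: the code that consumes stateStoreMetadata (now a constructor argument rather than a value recomputed in every partition reader) is unchanged and therefore collapsed in this view. A hedged sketch of how that continuation presumably reads, with the numColsPrefixKey field on StateMetadataTableEntry assumed from the comment about the prefix scan encoder in StateDataSource above:

// Sketch only -- the real continuation is not shown in this diff.
// With no matching metadata entry (checkpoint written by an older Spark),
// fall back to zero prefix-key columns; session-window aggregation state
// would then not be readable, as the truncated logWarning above explains.
val numColsPrefixKey: Int = if (stateStoreMetadata.isEmpty) {
  0
} else {
  // After filtering by operatorId and storeName, a single entry is expected.
  stateStoreMetadata.head.numColsPrefixKey
}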
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScan.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors}
 import org.apache.spark.sql.types.StructType
@@ -35,8 +36,10 @@ class StateScanBuilder(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends ScanBuilder {
-  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf)
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder {
+  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
+    stateStoreMetadata)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -50,7 +53,8 @@ class StateScan(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends Scan with Batch {
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   private val hadoopConfBroadcast = session.sparkContext.broadcast(
@@ -62,7 +66,8 @@ class StateScan(
       val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
       val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
         override def accept(path: Path): Boolean = {
-          fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+          fs.getFileStatus(path).isDirectory &&
+            Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
         }
       })
 
@@ -116,7 +121,8 @@ class StateScan(
         hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
     case JoinSideValues.none =>
-      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema)
+      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
+        stateStoreMetadata)
   }
 
   override def toBatch: Batch = this
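Besides threading stateStoreMetadata through, the hunk at @@ -62,7 +66,8 @@ also swaps fs.isDirectory(path) for fs.getFileStatus(path).isDirectory; FileSystem.isDirectory(Path) is deprecated in Hadoop, and the replacement is equivalent when the path is known to exist, which it is here since it comes straight from listStatus. A standalone sketch of that partition-directory filter, with the listing root as a placeholder:

import scala.util.Try

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}

// Keep only subdirectories whose names are non-negative integers, i.e. the
// per-partition state directories 0, 1, 2, ... under <checkpoint>/state/<operatorId>.
def listStatePartitionDirs(root: Path, hadoopConf: Configuration): Array[Path] = {
  val fs: FileSystem = root.getFileSystem(hadoopConf)
  fs.listStatus(root, new PathFilter {
    override def accept(path: Path): Boolean =
      // Non-deprecated form of the old fs.isDirectory(path) check.
      fs.getFileStatus(path).isDirectory &&
        Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
  }).map(_.getPath)
}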
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -35,7 +36,8 @@ class StateTable(
     session: SparkSession,
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateConf: StateStoreConf)
+    stateConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
@@ -69,7 +71,7 @@ class StateTable(
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf)
+    new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata)
 
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 