
Commit

[SPARK-48770][SS] Change to read operator metadata once on driver to check if we can find info for numColsPrefixKey used for session window agg queries

### What changes were proposed in this pull request?
Change the state data source to read the operator metadata once on the driver, to check whether we can find the info for `numColsPrefixKey` that is used for session window aggregation queries.

### Why are the changes needed?
Avoids reading the operator metadata file multiple times on the executors; previously each `StatePartitionReader` re-read the file, once per state store partition.
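
Concretely, the metadata lookup moves from each `StatePartitionReader` into `StateDataSource.getTable` on the driver, and the filtered entries are threaded through `StateTable`, `StateScanBuilder`, `StateScan`, and `StatePartitionReaderFactory` down to the partition readers. A minimal sketch of the pattern, using simplified stand-in types rather than the actual Spark classes:

```scala
object StateMetadataSketch {
  // Simplified sketch only: StateMetadataEntry and readMetadataFile stand in for
  // StateMetadataTableEntry / StateMetadataPartitionReader and are not the real Spark APIs.
  final case class StateMetadataEntry(operatorId: Long, stateStoreName: String, numColsPrefixKey: Int)

  // Driver side: read and filter the operator metadata exactly once, when the table is created.
  def loadStoreMetadataOnDriver(
      readMetadataFile: () => Seq[StateMetadataEntry],
      operatorId: Long,
      storeName: String): Array[StateMetadataEntry] = {
    readMetadataFile()
      .filter(e => e.operatorId == operatorId && e.stateStoreName == storeName)
      .toArray
  }

  // Executor side: each partition reader receives the filtered entries as a constructor argument
  // instead of re-reading the metadata file for every partition.
  class PartitionReaderSketch(stateStoreMetadata: Array[StateMetadataEntry]) {
    // Fall back to no prefix-key columns when no metadata entry is found (e.g. older checkpoints).
    val numColsPrefixKey: Int =
      if (stateStoreMetadata.isEmpty) 0 else stateStoreMetadata.head.numColsPrefixKey
  }
}
```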

### Does this PR introduce _any_ user-facing change?
No
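
For context, the user-facing read path is untouched. A rough sketch of how the state data source is typically queried (a hypothetical example; the option names are assumed from `StateSourceOptions` rather than taken from this diff):

```scala
// Hypothetical usage sketch; the "operatorId" / "storeName" option names are assumptions.
val stateDf = spark.read
  .format("statestore")                    // shortName() registered by StateDataSource
  .option("operatorId", "0")               // e.g. the session window aggregation operator
  .option("storeName", "default")
  .load("/path/to/streaming/checkpoint")   // checkpoint location of the streaming query

stateDf.show()
```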

### How was this patch tested?
Existing unit tests

```
===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.datasources.v2.state.RocksDBStateDataSourceReadSuite, threads: ForkJoinPool.commonPool-worker-6 (daemon=true), ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPool-worker-8 (daemon=true), shuffle-boss-6-1 (daemon=tru...
[info] Run completed in 1 minute, 39 seconds.
[info] Total number of tests run: 14
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 14, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47167 from anishshri-db/task/SPARK-48770.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Jul 2, 2024
1 parent ee0d306 commit fea930a
Showing 4 changed files with 37 additions and 19 deletions.
StateDataSource.scala:

@@ -30,13 +30,15 @@ import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
 import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
@@ -46,6 +48,8 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()
 
+  private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf)
+
   override def shortName(): String = "statestore"
 
   override def getTable(
@@ -54,7 +58,17 @@
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    new StateTable(session, schema, sourceOptions, stateConf)
+    // Read the operator metadata once to see if we can find the information for prefix scan
+    // encoder used in session window aggregation queries.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf)
+      .stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == sourceOptions.operatorId &&
+        entry.stateStoreName == sourceOptions.storeName
+    }
+
+    new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
StatePartitionReader.scala:

@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.types.StructType
@@ -33,11 +33,12 @@ import org.apache.spark.util.SerializableConfiguration
 class StatePartitionReaderFactory(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
-    schema: StructType) extends PartitionReaderFactory {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     new StatePartitionReader(storeConf, hadoopConf,
-      partition.asInstanceOf[StateStoreInputPartition], schema)
+      partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
   }
 }
 
@@ -49,7 +50,9 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] with Logging {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
+  extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -58,13 +61,6 @@ class StatePartitionReader(
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == partition.sourceOptions.operatorId &&
-        entry.stateStoreName == partition.sourceOptions.storeName
-    }
     val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
       logWarning("Metadata for state store not found, possible cause is this checkpoint " +
         "is created by older version of spark. If the query has session window aggregation, " +
StateScanBuilder.scala:

@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors}
 import org.apache.spark.sql.types.StructType
@@ -35,8 +36,10 @@ class StateScanBuilder(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends ScanBuilder {
-  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf)
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder {
+  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
+    stateStoreMetadata)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -50,7 +53,8 @@ class StateScan(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends Scan with Batch {
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   private val hadoopConfBroadcast = session.sparkContext.broadcast(
@@ -62,7 +66,8 @@
     val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
     val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
       override def accept(path: Path): Boolean = {
-        fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+        fs.getFileStatus(path).isDirectory &&
+          Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
       }
     })
 
@@ -116,7 +121,8 @@
          hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
      case JoinSideValues.none =>
-        new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema)
+        new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
+          stateStoreMetadata)
    }
 
   override def toBatch: Batch = this
StateTable.scala:

@@ -24,6 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -35,7 +36,8 @@ class StateTable(
     session: SparkSession,
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateConf: StateStoreConf)
+    stateConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
@@ -69,7 +71,7 @@
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf)
+    new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata)
 
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 
