check validity of input to options
eason-yuchen-liu committed Jul 1, 2024
1 parent 3834cc9 commit ace711c
Showing 4 changed files with 98 additions and 130 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -28,9 +28,8 @@ import org.apache.spark.sql.{RuntimeConfig, SparkSession}
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, StateDataSourceModeType}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.StateDataSourceModeType.ModeType
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
@@ -81,11 +80,11 @@ class StateDataSource extends TableProvider with DataSourceRegister {
manager.readSchemaFile()
}

if (sourceOptions.modeType == StateDataSourceModeType.CDC) {
if (sourceOptions.readChangeFeed) {
new StructType()
.add("key", keySchema)
.add("value", valueSchema)
.add("operation_type", StringType)
.add("change_type", StringType)
.add("batch_id", LongType)
.add("partition_id", IntegerType)
} else {
@@ -130,10 +129,10 @@ case class StateSourceOptions(
storeName: String,
joinSide: JoinSideValues,
snapshotStartBatchId: Option[Long],
snapshotPartitionId: Option[Int],
modeType: ModeType,
cdcStartBatchID: Option[Long],
cdcEndBatchId: Option[Long]) {
snapshotPartitionId: Option[Int],
readChangeFeed: Boolean,
changeStartBatchId: Option[Long],
changeEndBatchId: Option[Long]) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -152,31 +151,15 @@ object StateSourceOptions extends DataSourceOptions {
val JOIN_SIDE = newOption("joinSide")
val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")
val MODE_TYPE = newOption("modeType")
val CDC_START_BATCH_ID = newOption("cdcStartBatchId")
val CDC_END_BATCH_ID = newOption("cdcEndBatchId")
val READ_CHANGE_FEED = newOption("readChangeFeed")
val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
val left, right, none = Value
}

object StateDataSourceModeType extends Enumeration {
type ModeType = Value

val NORMAL = Value("normal")
val CDC = Value("cdc")

// Generate record type from byte representation
def getModeTypeFromString(mode: String): ModeType = {
mode match {
case "normal" => NORMAL
case "cdc" => CDC
case _ => throw new RuntimeException(s"Found invalid mode type for value=$mode")
}
}
}

def apply(
sparkSession: SparkSession,
hadoopConf: Configuration,
@@ -251,39 +234,34 @@ object StateSourceOptions extends DataSourceOptions {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
}

val modeType = Option(options.get(MODE_TYPE)).map(
StateDataSourceModeType.getModeTypeFromString).getOrElse(StateDataSourceModeType.NORMAL)

var cdcStartBatchId = Option(options.get(CDC_START_BATCH_ID)).map(_.toLong)
var cdcEndBatchId = Option(options.get(CDC_END_BATCH_ID)).map(_.toLong)

// if (modeType == StateDataSourceModeType.NORMAL) {
// if (cdcStartBatchId.isDefined) {
// throw StateDataSourceErrors.conflictOptions(Seq(MODE_TYPE, CDC_START_BATCH_ID))
// }
// if (cdcEndBatchId.isDefined) {
// throw StateDataSourceErrors.conflictOptions(Seq(MODE_TYPE, CDC_END_BATCH_ID))
// }
// } else {
// cdcStartBatchId = Option(cdcStartBatchId.getOrElse(
// getFirstCommittedBatch(sparkSession, resolvedCpLocation)))
// cdcEndBatchId = Option(cdcEndBatchId.getOrElse(
// getLastCommittedBatch(sparkSession, resolvedCpLocation)))
//
// if (cdcStartBatchId.isDefined && cdcStartBatchId.get < 0) {
// throw StateDataSourceErrors.invalidOptionValueIsNegative(CDC_START_BATCH_ID)
// }
// if (cdcEndBatchId.isDefined && (cdcEndBatchId.get < 0)) {
// throw
// }
//
// }
val readChangeFeed = Option(options.get(READ_CHANGE_FEED)).exists(_.toBoolean)

val changeStartBatchId = Option(options.get(CHANGE_START_BATCH_ID)).map(_.toLong)
var changeEndBatchId = Option(options.get(CHANGE_END_BATCH_ID)).map(_.toLong)

if (readChangeFeed) {
if (joinSide != JoinSideValues.none) {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, READ_CHANGE_FEED))
}
if (changeStartBatchId.isEmpty) {
throw StateDataSourceErrors.requiredOptionUnspecified(CHANGE_START_BATCH_ID)
}
changeEndBatchId = Option(
changeEndBatchId.getOrElse(getLastCommittedBatch(sparkSession, resolvedCpLocation)))
} else {
if (changeStartBatchId.isDefined) {
throw StateDataSourceErrors.conflictOptions(Seq(READ_CHANGE_FEED, CHANGE_START_BATCH_ID))
}
if (changeEndBatchId.isDefined) {
throw StateDataSourceErrors.conflictOptions(Seq(READ_CHANGE_FEED, CHANGE_END_BATCH_ID))
}
}

StateSourceOptions(
resolvedCpLocation, batchId, operatorId, storeName,
joinSide, snapshotStartBatchId, snapshotPartitionId,
modeType, cdcStartBatchId, cdcEndBatchId)
readChangeFeed, changeStartBatchId, changeEndBatchId)
}

private def resolvedCheckpointLocation(
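
The checks above reject readChangeFeed combined with joinSide, require changeStartBatchId whenever the change feed is enabled, and reject the change-feed batch options when it is not. A minimal usage sketch, assuming an active SparkSession named spark, that the source is registered under the short name statestore, and a hypothetical checkpoint location:

// Read the state change feed for batches 0 through 5.
// changeEndBatchId may be omitted; the validation above then falls back to
// the last committed batch via getLastCommittedBatch.
val changes = spark.read
  .format("statestore")
  .option("readChangeFeed", "true")
  .option("changeStartBatchId", "0")
  .option("changeEndBatchId", "5")
  .load("/tmp/streaming-checkpoint")  // hypothetical checkpoint location

// Passing changeStartBatchId or changeEndBatchId without readChangeFeed,
// or combining readChangeFeed with joinSide, fails with a conflictOptions error.
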
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -20,7 +20,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.StateDataSourceModeType
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state._
@@ -40,7 +39,7 @@ class StatePartitionReaderFactory(

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
if (stateStoreInputPartition.sourceOptions.modeType == StateDataSourceModeType.CDC) {
if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema)
} else {
@@ -164,8 +163,8 @@ class StateStoreChangeDataPartitionReader(
}
provider.asInstanceOf[SupportsFineGrainedReplay]
.getStateStoreChangeDataReader(
partition.sourceOptions.cdcStartBatchID.get + 1,
partition.sourceOptions.cdcEndBatchId.get + 1)
partition.sourceOptions.changeStartBatchId.get + 1,
partition.sourceOptions.changeEndBatchId.get + 1)
}

override protected lazy val iter: Iterator[InternalRow] = {
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, StateDataSourceModeType}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -74,8 +74,8 @@ class StateTable(
override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

private def isValidSchema(schema: StructType): Boolean = {
if (sourceOptions.modeType == StateDataSourceModeType.CDC) {
return isValidSchemaCDC(schema)
if (sourceOptions.readChangeFeed) {
return isValidChangeDataSchema(schema)
}
if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
false
@@ -90,15 +90,15 @@
}
}

private def isValidSchemaCDC(schema: StructType): Boolean = {
private def isValidChangeDataSchema(schema: StructType): Boolean = {
if (schema.fieldNames.toImmutableArraySeq !=
Seq("key", "value", "operation_type", "batch_id", "partition_id")) {
Seq("key", "value", "change_type", "batch_id", "partition_id")) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "operation_type").isInstanceOf[StringType]) {
} else if (!SchemaUtil.getSchemaAsDataType(schema, "change_type").isInstanceOf[StringType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) {
false
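
For reference, the shape that isValidChangeDataSchema accepts is the same StructType that StateDataSource builds in change-feed mode. A sketch, with keySchema and valueSchema standing in as placeholders for the operator's actual key and value state schemas:

// Expected change-feed read schema, mirroring the checks above.
val changeFeedSchema = new StructType()
  .add("key", keySchema)              // StructType of the state key
  .add("value", valueSchema)          // StructType of the state value
  .add("change_type", StringType)     // kind of change recorded for the row
  .add("batch_id", LongType)          // batch that produced the change
  .add("partition_id", IntegerType)   // state store partition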
