support hdfs state store provider
eason-yuchen-liu committed Jun 21, 2024
1 parent 752cdc7 commit cf84d50
Showing 5 changed files with 71 additions and 43 deletions.
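As background for the diff below, here is a minimal sketch of how the CDC read mode added by this commit could be exercised through the existing statestore data source. The option key strings are assumptions inferred from the constant names MODE_TYPE, CDC_START_BATCH_ID and CDC_END_BATCH_ID; only the constants, not their string values, appear in this change.

// Hypothetical usage sketch (option keys are assumed, not confirmed by this diff).
val changes = spark.read
  .format("statestore")
  .option("path", "/path/to/checkpoint")  // checkpoint location of the streaming query
  .option("modeType", "cdc")              // assumed key/value for StateDataSourceModeType.CDC
  .option("cdcStartBatchId", 1L)          // assumed key; defaults to the first committed batch
  .option("cdcEndBatchId", 4L)            // assumed key; defaults to the last committed batch
  .load()

// Per isValidSchemaCDC below, CDC rows expose:
//   key (struct), value (struct), operation_type (string), batch_id (long), partition_id (int)
changes.show()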
@@ -253,8 +253,31 @@ object StateSourceOptions extends DataSourceOptions {

val modeType = Option(options.get(MODE_TYPE)).map(
StateDataSourceModeType.getModeTypeFromString).getOrElse(StateDataSourceModeType.NORMAL)
var cdcStartBatchId = Option(options.get(CDC_START_BATCH_ID)).map(_.toLong)
var cdcEndBatchId = Option(options.get(CDC_END_BATCH_ID)).map(_.toLong)

if (modeType == StateDataSourceModeType.NORMAL) {
// The CDC batch range options are only meaningful in CDC mode.
if (cdcStartBatchId.isDefined) {
throw StateDataSourceErrors.conflictOptions(Seq(MODE_TYPE, CDC_START_BATCH_ID))
}
if (cdcEndBatchId.isDefined) {
throw StateDataSourceErrors.conflictOptions(Seq(MODE_TYPE, CDC_END_BATCH_ID))
}
} else {
// Default the range to the full committed history when the bounds are not given.
cdcStartBatchId = Option(cdcStartBatchId.getOrElse(
getFirstCommittedBatch(sparkSession, resolvedCpLocation)))
cdcEndBatchId = Option(cdcEndBatchId.getOrElse(
getLastCommittedBatch(sparkSession, resolvedCpLocation)))

if (cdcStartBatchId.isDefined && cdcStartBatchId.get < 0) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(CDC_START_BATCH_ID)
}
if (cdcEndBatchId.isDefined && cdcEndBatchId.get < 0) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(CDC_END_BATCH_ID)
}
}


StateSourceOptions(
@@ -279,4 +302,13 @@ object StateSourceOptions extends DataSourceOptions {
case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
}
}

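/** Return the earliest committed batch id recorded in the commit log, failing if no batch has been committed yet. */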
private def getFirstCommittedBatch(session: SparkSession, checkpointLocation: String): Long = {
val commitLog = new CommitLog(session,
new Path(checkpointLocation, DIR_NAME_COMMITS).toString)
commitLog.getEarliestBatchId() match {
case Some(firstId) => firstId
case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
}
}
}
@@ -23,10 +23,10 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, StateDataSourceModeType}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._

@@ -74,7 +74,9 @@ class StateTable(
override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

private def isValidSchema(schema: StructType): Boolean = {
// CDC mode returns a different set of columns, so validate it separately.
if (sourceOptions.modeType == StateDataSourceModeType.CDC) {
return isValidSchemaCDC(schema)
}
if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
@@ -88,6 +90,25 @@
}
}

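// CDC output carries two extra columns, operation_type and batch_id, on top of the
// normal key / value / partition_id layout, so its schema is validated separately.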
private def isValidSchemaCDC(schema: StructType): Boolean = {
if (schema.fieldNames.toImmutableArraySeq !=
Seq("key", "value", "operation_type", "batch_id", "partition_id")) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "operation_type").isInstanceOf[StringType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "batch_id").isInstanceOf[LongType]) {
false
} else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
false
} else {
true
}
}

override def metadataColumns(): Array[MetadataColumn] = Array.empty
}

@@ -264,6 +264,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
/** Return the latest batch id without reading the file. */
def getLatestBatchId(): Option[Long] = listBatches.sorted.lastOption

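/** Return the earliest batch id without reading the file. */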
def getEarliestBatchId(): Option[Long] = listBatches.sorted.headOption

override def getLatest(): Option[(Long, T)] = {
listBatches.sorted.lastOption.map { batchId =>
logInfo(log"Getting latest batch ${MDC(BATCH_ID, batchId)}")
@@ -66,21 +66,6 @@ abstract class StateStoreCDCReader(
}
}


protected lazy val fileIterator =
new ChangeLogFileIterator(stateLocation, startVersion, endVersion)

@@ -92,7 +77,6 @@ abstract class StateStoreCDCReader(
override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long)

def close(): Unit
}

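// Iterates over the changelog files written for versions [startVersion, endVersion] under
// stateLocation and exposes each change as an (operation type, key, value, batch id) tuple.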
class HDFSBackedStateStoreCDCReader(
@@ -112,29 +96,32 @@ class HDFSBackedStateStoreCDCReader(

override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
// Advance to the first changelog file in the range that still has records,
// closing each exhausted reader along the way.
while (currentChangelogReader == null || !currentChangelogReader.hasNext) {
if (currentChangelogReader != null) {
currentChangelogReader.close()
currentChangelogReader = null
}
if (!fileIterator.hasNext) {
finished = true
return null
}
currentChangelogReader =
new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec)
}

val readResult = currentChangelogReader.next()
val keyRow = new UnsafeRow(keySchema.fields.length)
keyRow.pointTo(readResult._2, readResult._2.length)
val valueRow = new UnsafeRow(valueSchema.fields.length)
// If valueSize in the existing file is not a multiple of 8, floor it to a multiple of 8.
// This is a workaround for the following: prior to Spark 2.3, `RowBasedKeyValueBatch`
// mistakenly appended 4 bytes to the value row, which gets persisted into the checkpoint data.
valueRow.pointTo(readResult._3, (readResult._3.length / 8) * 8)
// fileIterator has already advanced past the file being read, so its version is getVersion - 1.
(readResult._1, keyRow, valueRow, fileIterator.getVersion - 1)
}

override def close(): Unit = {
if (currentChangelogReader != null) {
currentChangelogReader.close()
}
}
}
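A rough sketch of draining the new reader end to end. Both the provider-level accessor getStateStoreCDCReader(startVersion, endVersion) and direct access to getNext() are assumptions (getNext() may be protected under the NextIterator contract); getNext() returns null once every changelog file in the range is consumed, as the finished flag above indicates.

// Hypothetical driver loop; `provider` is an HDFSBackedStateStoreProvider already
// initialized for the state directory (assumed).
val reader = provider.getStateStoreCDCReader(1, 4)
try {
  var record = reader.getNext()
  while (record != null) {
    val (recordType, key, value, batchId) = record
    println(s"batch=$batchId op=$recordType key=$key value=$value")
    record = reader.getNext()
  }
} finally {
  reader.close()
}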
@@ -398,20 +398,6 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
test("option snapshotPartitionId") {
testSnapshotPartitionId()
}

test("just test") {
val provider = getNewStateStoreProvider("/tmp/spark/state")
.asInstanceOf[HDFSBackedStateStoreProvider]
val reader = provider.getStateStoreCDCReader(1, 4)
println(reader.getNext()) // why is the first element null
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())
println(reader.getNext())

}
}

class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite {
