From 08d6f63ce1848f619de115bc476747c92f81636c Mon Sep 17 00:00:00 2001 From: Prakhar Jain Date: Mon, 17 Apr 2023 23:39:43 -0700 Subject: [PATCH] Refactoring around CheckpointMetadata - Move it from Snapshot to LogSegment. GitOrigin-RevId: 200db588b57e8a5fdbf43a1199533a54fb66b4d0 --- .../apache/spark/sql/delta/Checkpoints.scala | 74 ++++++++---- .../org/apache/spark/sql/delta/Snapshot.scala | 7 +- .../spark/sql/delta/SnapshotManagement.scala | 111 +++++++----------- .../sql/delta/SnapshotManagementSuite.scala | 4 +- 4 files changed, 98 insertions(+), 98 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala b/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala index d8cf033d854..861b457cc3a 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala @@ -97,6 +97,11 @@ case class CheckpointMetaData( case Some(_) => CheckpointMetaData.Format.WITH_PARTS case None => CheckpointMetaData.Format.SINGLE } + + /** Whether two [[CheckpointMetaData]] represents the same checkpoint */ + def semanticEquals(other: CheckpointMetaData): Boolean = { + CheckpointInstance(this) == CheckpointInstance(other) + } } object CheckpointMetaData { @@ -248,17 +253,16 @@ object CheckpointMetaData { s""""$result"""" } - def fromLogSegment(segment: LogSegment): Option[CheckpointMetaData] = { - segment.checkpointVersionOpt.map { version => - CheckpointMetaData( - version = version, - size = -1L, - parts = numCheckpointParts(segment.checkpoint.head.getPath), - sizeInBytes = Some(segment.checkpoint.map(_.getLen).sum), - numOfAddFiles = None, - checkpointSchema = None - ) - } + def fromFiles(files: Seq[FileStatus]): CheckpointMetaData = { + assert(files.nonEmpty, "files should be non empty to construct CheckpointMetaData") + CheckpointMetaData( + version = checkpointVersion(files.head), + size = -1L, + parts = numCheckpointParts(files.head.getPath), + sizeInBytes = Some(files.map(_.getLen).sum), + numOfAddFiles = None, + checkpointSchema = None + ) } } @@ -278,13 +282,15 @@ case class CheckpointInstance( s" ${CheckpointMetaData.Format.WITH_PARTS.name}") /** - * Returns a [[CheckpointFileListProvider]] which can tell the files corresponding to this + * Returns a [[CheckpointProvider]] which can tell the files corresponding to this * checkpoint. + * The `checkpointMetadataHint` might be passed to [[CheckpointProvider]] so that underlying + * [[CheckpointProvider]] provides more precise info. */ - def getCheckpointFileListProvider( + def getCheckpointProvider( logPath: Path, - filesForLegacyCheckpointConstruction: Seq[FileStatus], - checkpointMetadataHint: Option[CheckpointMetaData] = None): CheckpointFileListProvider = { + filesForCheckpointConstruction: Seq[FileStatus], + checkpointMetadataHint: Option[CheckpointMetaData] = None): CheckpointProvider = { format match { case CheckpointMetaData.Format.WITH_PARTS | CheckpointMetaData.Format.SINGLE => val filePaths = if (format == CheckpointMetaData.Format.WITH_PARTS) { @@ -293,12 +299,14 @@ case class CheckpointInstance( Set(checkpointFileSingular(logPath, version)) } val newCheckpointFileArray = - filesForLegacyCheckpointConstruction.filter(f => filePaths.contains(f.getPath)) + filesForCheckpointConstruction.filter(f => filePaths.contains(f.getPath)) assert(newCheckpointFileArray.length == filePaths.size, "Failed in getting the file information for:\n" + filePaths.mkString(" -", "\n -", "") + "\namong\n" + - filesForLegacyCheckpointConstruction.map(_.getPath).mkString(" -", "\n -", "")) - PreloadedCheckpointFileProvider(newCheckpointFileArray) + filesForCheckpointConstruction.map(_.getPath).mkString(" -", "\n -", "")) + PreloadedCheckpointProvider( + newCheckpointFileArray, + checkpointMetadataHint.filter(cm => CheckpointInstance(cm) == this)) case CheckpointMetaData.Format.SENTINEL => throw DeltaErrors.assertionFailedError( s"invalid checkpoint format ${CheckpointMetaData.Format.SENTINEL}") @@ -484,7 +492,8 @@ trait Checkpoints extends DeltaLogging { parts = cv.numParts, sizeInBytes = None, numOfAddFiles = None, - checkpointSchema = None) + checkpointSchema = None + ) } /** @@ -731,7 +740,8 @@ object Checkpoints extends DeltaLogging { parts = numPartsOption, sizeInBytes = Some(checkpointSizeInBytes), numOfAddFiles = Some(snapshot.numOfFiles), - checkpointSchema = checkpointSchemaToWriteInLastCheckpointFile(spark, schema)) + checkpointSchema = checkpointSchemaToWriteInLastCheckpointFile(spark, schema) + ) } /** @@ -859,18 +869,32 @@ object CheckpointV2 { } /** - * A trait which provides functionality to retrieve the underlying files for a Checkpoint. + * A trait which provides information about a checkpoint to the Snapshot. + * - files in the underlying checkpoint + * - metadata of the underlying checkpoint */ -trait CheckpointFileListProvider { +trait CheckpointProvider { def checkpointFiles: Seq[FileStatus] + def checkpointMetadata: CheckpointMetaData } /** - * An implementation of [[CheckpointFileListProvider]] where the information about checkpoint files + * An implementation of [[CheckpointProvider]] where the information about checkpoint files * (i.e. Seq[FileStatus]) is already known in advance. + * + * @param checkpointFiles - file statuses for the checkpoint + * @param checkpointMetadataOpt - optional checkpoint metadata for the checkpoint. + * If this is passed, the provider will use it instead of deriving the + * [[CheckpointMetaData]] from the file list. */ -case class PreloadedCheckpointFileProvider( - override val checkpointFiles: Seq[FileStatus]) extends CheckpointFileListProvider { +case class PreloadedCheckpointProvider( + override val checkpointFiles: Seq[FileStatus], + checkpointMetadataOpt: Option[CheckpointMetaData] +) extends CheckpointProvider { + + override def checkpointMetadata: CheckpointMetaData = { + checkpointMetadataOpt.getOrElse(CheckpointMetaData.fromFiles(checkpointFiles)) + } require(checkpointFiles.nonEmpty, "There should be atleast 1 checkpoint file") } diff --git a/core/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala b/core/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala index 54e074d47cf..b79069478db 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala @@ -68,8 +68,8 @@ class Snapshot( val logSegment: LogSegment, override val deltaLog: DeltaLog, val timestamp: Long, - val checksumOpt: Option[VersionChecksum], - checkpointMetadataOpt: Option[CheckpointMetaData] = None) + val checksumOpt: Option[VersionChecksum] + ) extends SnapshotDescriptor with SnapshotStateManager with StateCache @@ -354,7 +354,8 @@ class Snapshot( } } - def getCheckpointMetadataOpt: Option[CheckpointMetaData] = checkpointMetadataOpt + def getCheckpointMetadataOpt: Option[CheckpointMetaData] = + logSegment.checkpointProviderOpt.map(_.checkpointMetadata) def redactedPath: String = Utils.redact(spark.sessionState.conf.stringRedactionPattern, path.toUri.toString) diff --git a/core/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala b/core/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala index a9461d4af42..635c489af90 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala @@ -64,7 +64,11 @@ trait SnapshotManagement { self: DeltaLog => */ protected def getLogSegmentFrom( startingCheckpoint: Option[CheckpointMetaData]): Option[LogSegment] = { - getLogSegmentForVersion(startingCheckpoint.map(_.version)) + getLogSegmentForVersion( + startCheckpoint = startingCheckpoint.map(_.version), + versionToLoad = None, + checkpointMetadataHint = startingCheckpoint + ) } /** Get an iterator of files in the _delta_log directory starting with the startVersion. */ @@ -250,8 +254,8 @@ trait SnapshotManagement { self: DeltaLog => val newVersion = deltasAfterCheckpoint.lastOption.map(deltaVersion).getOrElse(newCheckpoint.get.version) - val checkpointFileListProvider = newCheckpoint - .map(_.getCheckpointFileListProvider(logPath, checkpoints, checkpointMetadataHint)) + val checkpointProvider = newCheckpoint + .map(_.getCheckpointProvider(logPath, checkpoints, checkpointMetadataHint)) // In the case where `deltasAfterCheckpoint` is empty, `deltas` should still not be empty, // they may just be before the checkpoint version unless we have a bug in log cleanup. @@ -270,8 +274,7 @@ trait SnapshotManagement { self: DeltaLog => logPath, newVersion, deltasAfterCheckpoint, - checkpointFileListProvider, - newCheckpoint.map(_.version), + checkpointProvider, lastCommitTimestamp)) } } @@ -287,7 +290,6 @@ trait SnapshotManagement { self: DeltaLog => val lastCheckpointOpt = readLastCheckpointFile() createSnapshotAtInitInternal( initSegment = getLogSegmentFrom(lastCheckpointOpt), - lastCheckpointOpt = lastCheckpointOpt, timestamp = currentTimestamp ) } @@ -295,12 +297,10 @@ trait SnapshotManagement { self: DeltaLog => protected def createSnapshotAtInitInternal( initSegment: Option[LogSegment], - lastCheckpointOpt: Option[CheckpointMetaData], timestamp: Long): CapturedSnapshot = { val snapshot = initSegment.map { segment => val snapshot = createSnapshot( initSegment = segment, - checkpointMetadataOptHint = lastCheckpointOpt, checksumOpt = None) snapshot }.getOrElse { @@ -342,7 +342,6 @@ trait SnapshotManagement { self: DeltaLog => protected def createSnapshot( initSegment: LogSegment, - checkpointMetadataOptHint: Option[CheckpointMetaData], checksumOpt: Option[VersionChecksum]): Snapshot = { val startingFrom = initSegment.checkpointVersionOpt .map(v => s" starting from checkpoint version $v.").getOrElse(".") @@ -354,27 +353,11 @@ trait SnapshotManagement { self: DeltaLog => logSegment = segment, deltaLog = this, timestamp = segment.lastCommitTimestamp, - checksumOpt = checksumOpt.orElse(readChecksum(segment.version)), - checkpointMetadataOpt = getCheckpointMetadataForSegment(segment, checkpointMetadataOptHint)) + checksumOpt = checksumOpt.orElse(readChecksum(segment.version)) + ) } } - /** - * Returns the [[CheckpointMetaData]] for the given [[LogSegment]]. - * If the passed `checkpointMetadataOptHint` matches the `segment`, then it is returned - * directly. - */ - protected def getCheckpointMetadataForSegment( - segment: LogSegment, - checkpointMetadataOptHint: Option[CheckpointMetaData]): Option[CheckpointMetaData] = { - // validate that `checkpointMetadataOptHint` and `segment` has same info regarding the - // checkpoint version and parts. - val checkpointMatches = - (segment.checkpointVersionOpt == checkpointMetadataOptHint.map(_.version)) && - (segment.checkpoint.size == checkpointMetadataOptHint.flatMap(_.parts).getOrElse(1)) - if (checkpointMatches) checkpointMetadataOptHint else CheckpointMetaData.fromLogSegment(segment) - } - /** * Returns a [[LogSegment]] for reading `snapshotVersion` such that the segment's checkpoint * version (if checkpoint present) is LESS THAN `maxExclusiveCheckpointVersion`. @@ -410,8 +393,8 @@ trait SnapshotManagement { self: DeltaLog => } // `checkpoints` may contain multiple checkpoints for different part sizes, we need to // search `FileStatus`s of the checkpoint files for `cp`. - val checkpointFileListProvider = cp.getCheckpointFileListProvider( - logPath, checkpoints, checkpointMetadataHint = None) + val checkpointProvider = + cp.getCheckpointProvider(logPath, checkpoints, checkpointMetadataHint = None) // Create the list of `FileStatus`s for delta files after `cp.version`. val deltasAfterCheckpoint = deltas.filter { file => deltaVersion(file) > cp.version @@ -429,8 +412,7 @@ trait SnapshotManagement { self: DeltaLog => logPath, snapshotVersion, deltas, - Some(checkpointFileListProvider), - Some(cp.version), + Some(checkpointProvider), deltas.last.getModificationTime)) case None => val (deltas, deltaVersions) = @@ -449,14 +431,15 @@ trait SnapshotManagement { self: DeltaLog => logPath = logPath, version = snapshotVersion, deltas = deltas, - checkpoint = Nil, - checkpointVersionOpt = None, + checkpointProviderOpt = None, lastCommitTimestamp = deltas.last.getModificationTime)) } } /** Used to compute the LogSegment after a commit */ - protected[delta] def getLogSegmentAfterCommit(preCommitLogSegment: LogSegment): LogSegment = { + protected[delta] def getLogSegmentAfterCommit( + preCommitLogSegment: LogSegment, + checkpointMetadataHint: Option[CheckpointMetaData]): LogSegment = { /** * We can't specify `versionToLoad = committedVersion` for the call below. * If there are a lot of concurrent commits to the table on the same cluster, each @@ -466,7 +449,11 @@ trait SnapshotManagement { self: DeltaLog => * Instead, just do a general update to the latest available version. The racing commits * can then use the version check short-circuit to avoid constructing a new snapshot. */ - getLogSegmentForVersion(preCommitLogSegment.checkpointVersionOpt).getOrElse { + getLogSegmentForVersion( + startCheckpoint = preCommitLogSegment.checkpointVersionOpt, + versionToLoad = None, + checkpointMetadataHint = checkpointMetadataHint + ).getOrElse { // This shouldn't be possible right after a commit logError(s"No delta log found for the Delta table at $logPath") throw DeltaErrors.emptyDirectoryException(logPath.toString) @@ -658,7 +645,6 @@ trait SnapshotManagement { self: DeltaLog => } else { val newSnapshot = createSnapshot( initSegment = segment, - checkpointMetadataOptHint = previousSnapshot.getCheckpointMetadataOpt, checksumOpt = None) logMetadataTableIdChange(previousSnapshot, newSnapshot) logInfo(s"Updated snapshot to $newSnapshot") @@ -702,12 +688,10 @@ trait SnapshotManagement { self: DeltaLog => protected def createSnapshotAfterCommit( initSegment: LogSegment, newChecksumOpt: Option[VersionChecksum], - committedVersion: Long, - checkpointMetadataOptHint: Option[CheckpointMetaData]): Snapshot = { + committedVersion: Long): Snapshot = { logInfo(s"Creating a new snapshot v${initSegment.version} for commit version $committedVersion") createSnapshot( initSegment, - checkpointMetadataOptHint, checksumOpt = newChecksumOpt ) } @@ -730,7 +714,8 @@ trait SnapshotManagement { self: DeltaLog => // Somebody else could have already updated the snapshot while we waited for the lock if (committedVersion <= previousSnapshot.version) return previousSnapshot val segment = getLogSegmentAfterCommit( - preCommitLogSegment) + preCommitLogSegment, + checkpointMetadataHint = previousSnapshot.getCheckpointMetadataOpt) // This likely implies a list-after-write inconsistency if (segment.version < committedVersion) { @@ -744,8 +729,7 @@ trait SnapshotManagement { self: DeltaLog => val newSnapshot = createSnapshotAfterCommit( segment, newChecksumOpt, - committedVersion, - previousSnapshot.getCheckpointMetadataOpt) + committedVersion) logMetadataTableIdChange(previousSnapshot, newSnapshot) logInfo(s"Updated snapshot to $newSnapshot") replaceSnapshot(newSnapshot, updateTimestamp) @@ -771,7 +755,6 @@ trait SnapshotManagement { self: DeltaLog => getLogSegmentForVersion(startingCheckpoint.map(_.version), Some(version)).map { segment => createSnapshot( initSegment = segment, - checkpointMetadataOptHint = None, checksumOpt = None) }.getOrElse { // We can't return InitialSnapshot because our caller asked for a specific snapshot version. @@ -816,11 +799,19 @@ object SnapshotManagement { def appendCommitToLogSegment( oldLogSegment: LogSegment, commitFileStatus: FileStatus, - committedVersion: Long): LogSegment = { + committedVersion: Long, + checkpointMetadataHint: Option[CheckpointMetaData]): LogSegment = { require(oldLogSegment.version + 1 == committedVersion) + val checkpointProvider = oldLogSegment.checkpointProviderOpt match { + case Some(provider: PreloadedCheckpointProvider) + if checkpointMetadataHint.forall(_.semanticEquals(provider.checkpointMetadata)) => + Some(provider.copy(checkpointMetadataOpt = checkpointMetadataHint)) + case other => other + } oldLogSegment.copy( version = committedVersion, deltas = oldLogSegment.deltas :+ commitFileStatus, + checkpointProviderOpt = checkpointProvider, lastCommitTimestamp = commitFileStatus.getModificationTime) } } @@ -869,9 +860,8 @@ object SerializableFileStatus { * @param logPath The path to the _delta_log directory * @param version The Snapshot version to generate * @param deltas The delta commit files (.json) to read - * @param checkpointFileListProviderOpt provider to give information about Checkpoint files. This - * should be non-empty if [[checkpointVersionOpt]] is present. - * @param checkpointVersionOpt The checkpoint version used to start replay + * @param checkpointProviderOpt provider to give information about Checkpoint files. This + * should be non-empty if [[checkpointVersionOpt]] is present. * @param lastCommitTimestamp The "unadjusted" timestamp of the last commit within this segment. By * unadjusted, we mean that the commit timestamps may not necessarily be * monotonically increasing for the commits within this segment. @@ -880,12 +870,14 @@ case class LogSegment( logPath: Path, version: Long, deltas: Seq[FileStatus], - checkpointFileListProviderOpt: Option[CheckpointFileListProvider], - checkpointVersionOpt: Option[Long], + checkpointProviderOpt: Option[CheckpointProvider], lastCommitTimestamp: Long) { + def checkpointVersionOpt: Option[Long] = + checkpointProviderOpt.map(_.checkpointMetadata.version) + def checkpoint: Seq[FileStatus] = - checkpointFileListProviderOpt.map(_.checkpointFiles).getOrElse(Nil) + checkpointProviderOpt.map(_.checkpointFiles).getOrElse(Nil) override def hashCode(): Int = logPath.hashCode() * 31 + (lastCommitTimestamp % 10000).toInt @@ -905,28 +897,11 @@ case class LogSegment( object LogSegment { - def apply( - logPath: Path, - version: Long, - deltas: Seq[FileStatus], - checkpoint: Seq[FileStatus], - checkpointVersionOpt: Option[Long], - lastCommitTimestamp: Long): LogSegment = { - LogSegment( - logPath, - version, - deltas, - if (checkpoint.nonEmpty) Some(PreloadedCheckpointFileProvider(checkpoint)) else None, - checkpointVersionOpt, - lastCommitTimestamp) - } - /** The LogSegment for an empty transaction log directory. */ def empty(path: Path): LogSegment = LogSegment( logPath = path, version = -1L, deltas = Nil, - checkpoint = Nil, - checkpointVersionOpt = None, + checkpointProviderOpt = None, lastCommitTimestamp = -1L) } diff --git a/core/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala index fc74cdd6f61..78b82b968db 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala @@ -437,9 +437,9 @@ class SnapshotManagementSuite extends QueryTest with SQLTestUtils with SharedSpa val oldLogSegment = log.snapshot.logSegment spark.range(10).write.format("delta").save(path) val newLogSegment = log.snapshot.logSegment - assert(log.getLogSegmentAfterCommit(oldLogSegment) === newLogSegment) + assert(log.getLogSegmentAfterCommit(oldLogSegment, None) === newLogSegment) spark.range(10).write.format("delta").mode("append").save(path) - assert(log.getLogSegmentAfterCommit(newLogSegment) === log.snapshot.logSegment) + assert(log.getLogSegmentAfterCommit(newLogSegment, None) === log.snapshot.logSegment) } }