[Spark] Add read support for defaultRowCommitVersion #2795

Merged 2 commits on Mar 26, 2024
DefaultRowCommitVersion.scala
@@ -17,7 +17,11 @@
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.util.ScalaExtensions._

import org.apache.spark.sql.catalyst.expressions.FileSourceConstantMetadataStructField
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{LongType, MetadataBuilder, StructField}

object DefaultRowCommitVersion {
  def assignIfMissing(
@@ -35,4 +39,39 @@ object DefaultRowCommitVersion {
        a
    }
  }

  def createDefaultRowCommitVersionField(
      protocol: Protocol, metadata: Metadata): Option[StructField] = {
    Option.when(RowTracking.isEnabled(protocol, metadata)) {
      MetadataStructField()
    }
  }

  val METADATA_STRUCT_FIELD_NAME = "default_row_commit_version"

  private object MetadataStructField {
    private val METADATA_COL_ATTR_KEY = "__default_row_version_metadata_col"

    def apply(): StructField =
      StructField(
        METADATA_STRUCT_FIELD_NAME,
        LongType,
        nullable = false,
        metadata = metadata)

    def unapply(field: StructField): Option[StructField] =
      Some(field).filter(isValid)

    private def metadata: types.Metadata = new MetadataBuilder()
      .withMetadata(FileSourceConstantMetadataStructField.metadata(METADATA_STRUCT_FIELD_NAME))
      .putBoolean(METADATA_COL_ATTR_KEY, value = true)
      .build()

    private def isValid(s: StructField): Boolean = {
      FileSourceConstantMetadataStructField.isValid(s.dataType, s.metadata) &&
        s.metadata.contains(METADATA_COL_ATTR_KEY) &&
        s.metadata.getBoolean(METADATA_COL_ATTR_KEY)
    }
  }
}
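For context on how the field defined above reaches readers: it is appended to the scan schema as part of Spark's hidden _metadata struct, so the per-file commit version can be selected like any other metadata subfield. A minimal usage sketch (not part of this diff), assuming an active spark session, a placeholder path pointing at a row-tracking-enabled Delta table, and an id data column as in the test at the bottom of this PR:

import org.apache.spark.sql.functions.col

// Read the table and surface the per-file default row commit version next to the data columns.
val withVersions = spark.read.format("delta").load(path)
  .select(col("id"), col("_metadata.default_row_commit_version"))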
DeltaParquetFileFormat.scala
@@ -201,7 +201,8 @@ case class DeltaParquetFileFormat(
    // causes it to run out of the `Integer` range (TODO: Create a SPARK issue)
    // For Delta Parquet readers don't expose the row_index field as a metadata field.
    super.metadataSchemaFields.filter(field => field != ParquetFileFormat.ROW_INDEX_FIELD) ++
      RowId.createBaseRowIdField(protocol, metadata)
      RowId.createBaseRowIdField(protocol, metadata) ++
      DefaultRowCommitVersion.createDefaultRowCommitVersionField(protocol, metadata)
  }

  override def prepareWrite(
@@ -232,8 +233,17 @@
s"Missing ${RowId.BASE_ROW_ID} value for file '${file.filePath}'")
})
}
val extractDefaultRowCommitVersion: PartitionedFile => Any = { file =>
file.otherConstantMetadataColumnValues
.getOrElse(DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME, {
throw new IllegalStateException(
s"Missing ${DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME} value " +
s"for file '${file.filePath}'")
})
}
super.fileConstantMetadataExtractors
.updated(RowId.BASE_ROW_ID, extractBaseRowId)
.updated(DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME, extractDefaultRowCommitVersion)
}

  def copyWithDVInfo(
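The extractor map returned above is how Spark materializes constant metadata columns: for each registered column name it invokes a PartitionedFile => Any function once per file and repeats the result for every row in that file. A simplified, self-contained sketch of the same lookup pattern, using plain Scala stand-ins rather than Spark's actual classes:

// Simplified model of a per-file constant metadata extractor registry.
// SketchFile stands in for Spark's PartitionedFile; the real extractor reads
// PartitionedFile.otherConstantMetadataColumnValues instead of this map.
case class SketchFile(path: String, constantMetadata: Map[String, Any])

val extractors: Map[String, SketchFile => Any] = Map(
  "default_row_commit_version" -> { (file: SketchFile) =>
    file.constantMetadata.getOrElse("default_row_commit_version",
      throw new IllegalStateException(
        s"Missing default_row_commit_version value for file '${file.path}'"))
  }
)

// Every row scanned from this file would report 2L in the metadata column.
val version = extractors("default_row_commit_version")(
  SketchFile("part-00000.parquet", Map("default_row_commit_version" -> 2L)))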
TahoeFileIndex.scala
@@ -23,6 +23,7 @@ import java.util.Objects
import scala.collection.mutable
import org.apache.spark.sql.delta.RowIndexFilterType
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaErrors, DeltaLog, NoMapping, Snapshot, SnapshotDescriptor}
import org.apache.spark.sql.delta.DefaultRowCommitVersion
import org.apache.spark.sql.delta.RowId
import org.apache.spark.sql.delta.actions.{AddFile, Metadata, Protocol}
import org.apache.spark.sql.delta.implicits._
@@ -122,6 +123,8 @@ abstract class TahoeFileIndex(
      /* path */ absolutePath(addFile.path))
    val metadata = mutable.Map.empty[String, Any]
    addFile.baseRowId.foreach(baseRowId => metadata.put(RowId.BASE_ROW_ID, baseRowId))
    addFile.defaultRowCommitVersion.foreach(defaultRowCommitVersion =>
      metadata.put(DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME, defaultRowCommitVersion))

    FileStatusWithMetadata(fs, metadata.toMap)
  }
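This is the producing side of the same plumbing: when an AddFile action carries a defaultRowCommitVersion, TahoeFileIndex attaches it to the file's constant metadata under DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME, the same key that the extractor registered in DeltaParquetFileFormat above reads back at scan time.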
DefaultRowCommitVersionSuite.scala
@@ -27,11 +27,13 @@ import org.apache.spark.sql.delta.test.DeltaTestImplicits._

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession

class DefaultRowCommitVersionSuite extends QueryTest
  with SharedSparkSession
  with ParquetTest
  with RowIdTestUtils {
  def expectedCommitVersionsForAllFiles(deltaLog: DeltaLog): Map[String, Long] = {
    val commitVersionForFiles = mutable.Map.empty[String, Long]
@@ -218,4 +220,26 @@ class DefaultRowCommitVersionSuite extends QueryTest
      }
    }
  }

test("can read default row commit versions") {
withRowTrackingEnabled(enabled = true) {
withTempDir { tempDir =>
spark.range(start = 0, end = 100, step = 1, numPartitions = 1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)
spark.range(start = 100, end = 200, step = 1, numPartitions = 1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)
spark.range(start = 200, end = 300, step = 1, numPartitions = 1)
.write.format("delta").mode("append").save(tempDir.getAbsolutePath)

withAllParquetReaders {
checkAnswer(
spark.read.format("delta").load(tempDir.getAbsolutePath)
.select("id", "_metadata.default_row_commit_version"),
(0L until 100L).map(Row(_, 0L)) ++
(100L until 200L).map(Row(_, 1L)) ++
(200L until 300L).map(Row(_, 2L)))
}
}
}
}
}
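Each of the three appends in the test creates one table commit, so the files written by the first, second, and third append carry default row commit versions 0, 1, and 2. The checkAnswer expectation encodes exactly that mapping for ids 0-99, 100-199, and 200-299.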