Support limit pushdown on Delta tables with DVs
GitOrigin-RevId: 6ef96d9f59ce3915020856d67a35422f38c5ae85
vkorukanti committed Jan 27, 2023
1 parent 2d7d625 commit 52c221a
Showing 6 changed files with 372 additions and 26 deletions.
@@ -19,29 +19,33 @@ package org.apache.spark.sql.delta.actions
import java.net.URI

/**
* Replays a history of action, resolving them to produce the current state
* Replays a history of actions, resolving them to produce the current state
* of the table. The protocol for resolution is as follows:
* - The most recent [[AddFile]] and accompanying metadata for any `path` wins.
* - The most recent [[AddFile]] and accompanying metadata for any `(path, dv id)` tuple wins.
* - [[RemoveFile]] deletes a corresponding [[AddFile]] and is retained as a
* tombstone until `minFileRetentionTimestamp` has passed.
* A [[RemoveFile]] "corresponds" to the [[AddFile]] that matches both the parquet file URI
* *and* the deletion vector's URI (if any).
* - The most recent version for any `appId` in a [[SetTransaction]] wins.
* - The most recent [[Metadata]] wins.
* - The most recent [[Protocol]] version wins.
* - For each path, this class should always output only one [[FileAction]] (either [[AddFile]] or
* [[RemoveFile]])
* - For each `(path, dv id)` tuple, this class should always output only one [[FileAction]]
* (either [[AddFile]] or [[RemoveFile]])
*
* This class is not thread safe.
*/
class InMemoryLogReplay(
minFileRetentionTimestamp: Long,
minSetTransactionRetentionTimestamp: Option[Long]) extends LogReplay {

var currentProtocolVersion: Protocol = null
var currentVersion: Long = -1
var currentMetaData: Metadata = null
val transactions = new scala.collection.mutable.HashMap[String, SetTransaction]()
val activeFiles = new scala.collection.mutable.HashMap[URI, AddFile]()
private val tombstones = new scala.collection.mutable.HashMap[URI, RemoveFile]()
import InMemoryLogReplay._

private var currentProtocolVersion: Protocol = null
private var currentVersion: Long = -1
private var currentMetaData: Metadata = null
private val transactions = new scala.collection.mutable.HashMap[String, SetTransaction]()
private val activeFiles = new scala.collection.mutable.HashMap[UniqueFileActionTuple, AddFile]()
private val tombstones = new scala.collection.mutable.HashMap[UniqueFileActionTuple, RemoveFile]()

override def append(version: Long, actions: Iterator[Action]): Unit = {
assert(currentVersion == -1 || version == currentVersion + 1,
@@ -55,14 +59,16 @@ class InMemoryLogReplay(
case a: Protocol =>
currentProtocolVersion = a
case add: AddFile =>
activeFiles(add.pathAsUri) = add.copy(dataChange = false)
val uniquePath = UniqueFileActionTuple(add.pathAsUri, add.getDeletionVectorUniqueId)
activeFiles(uniquePath) = add.copy(dataChange = false)
// Remove the tombstone to make sure we only output one `FileAction`.
tombstones.remove(add.pathAsUri)
tombstones.remove(uniquePath)
case remove: RemoveFile =>
activeFiles.remove(remove.pathAsUri)
tombstones(remove.pathAsUri) = remove.copy(dataChange = false)
case ci: CommitInfo => // do nothing
case cdc: AddCDCFile => // do nothing
val uniquePath = UniqueFileActionTuple(remove.pathAsUri, remove.getDeletionVectorUniqueId)
activeFiles.remove(uniquePath)
tombstones(uniquePath) = remove.copy(dataChange = false)
case _: CommitInfo => // do nothing
case _: AddCDCFile => // do nothing
case null => // Some crazy future feature. Ignore
}
}
@@ -71,7 +77,7 @@ class InMemoryLogReplay(
tombstones.values.filter(_.delTimestamp > minFileRetentionTimestamp)
}

private def getTransactions: Iterable[SetTransaction] = {
private[delta] def getTransactions: Iterable[SetTransaction] = {
if (minSetTransactionRetentionTimestamp.isEmpty) {
transactions.values
} else {
@@ -88,4 +94,12 @@
getTransactions ++
(activeFiles.values ++ getTombstones).toSeq.sortBy(_.path).iterator
}

/** Returns all [[AddFile]] actions after the Log Replay */
private[delta] def allFiles: Seq[AddFile] = activeFiles.values.toSeq
}

object InMemoryLogReplay{
/** The unit of path uniqueness in delta log actions is the tuple `(parquet file, dv)`. */
final case class UniqueFileActionTuple(fileURI: URI, deletionVectorURI: Option[String])
}
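
To make the replay protocol above concrete, here is a minimal, self-contained Scala sketch that uses simplified stand-in case classes (not the real Delta actions): attaching a deletion vector to a file is committed as a RemoveFile for the old (path, no-DV) entry plus an AddFile for the same parquet path that now carries a DV id, and replay keyed by the (path, dv id) tuple keeps exactly one FileAction per key.

import scala.collection.mutable

// Simplified stand-ins for AddFile/RemoveFile, for illustration only.
final case class UniqueKey(path: String, dvId: Option[String])
sealed trait MiniAction
final case class MiniAdd(path: String, dvId: Option[String]) extends MiniAction
final case class MiniRemove(path: String, dvId: Option[String]) extends MiniAction

object MiniReplayDemo extends App {
  val active = mutable.HashMap.empty[UniqueKey, MiniAdd]
  val tombstones = mutable.HashMap.empty[UniqueKey, MiniRemove]

  def append(actions: Seq[MiniAction]): Unit = actions.foreach {
    case a @ MiniAdd(path, dvId) =>
      val key = UniqueKey(path, dvId)
      active(key) = a          // latest AddFile for this (path, dv id) wins
      tombstones.remove(key)   // never output both an Add and a Remove for one key
    case r @ MiniRemove(path, dvId) =>
      val key = UniqueKey(path, dvId)
      active.remove(key)
      tombstones(key) = r
  }

  // Deleting rows from f1.parquet with a DV: remove the DV-less entry and
  // add a new entry for the same parquet file that now carries a DV.
  append(Seq(
    MiniAdd("f1.parquet", None),
    MiniRemove("f1.parquet", None),
    MiniAdd("f1.parquet", Some("dv-001"))))

  println(active.keySet)     // Set(UniqueKey(f1.parquet,Some(dv-001)))
  println(tombstones.keySet) // Set(UniqueKey(f1.parquet,None))
}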
@@ -412,6 +412,30 @@ case class AddFile(
removedFile
}

/**
* Logically remove rows by associating a `deletionVector` with the file.
* @param deletionVector: The descriptor of the DV that marks rows as deleted.
* @param dataChange: When false, the actions are marked as no-data-change actions.
*/
def removeRows(
deletionVector: DeletionVectorDescriptor,
dataChange: Boolean = true): (AddFile, RemoveFile) = {
val withUpdatedDV = this.copy(deletionVector = deletionVector, dataChange = dataChange)
val addFile = withUpdatedDV
val removeFile = this.removeWithTimestamp(dataChange = dataChange)
(addFile, removeFile)
}

/**
* Return the unique id of the deletion vector, if present, or `None` if there's no DV.
*
* The unique id differentiates DVs, even if there are multiple in the same file
* or the DV is stored inline.
*/
@JsonIgnore
def getDeletionVectorUniqueId: Option[String] = Option(deletionVector).map(_.uniqueId)


@JsonIgnore
lazy val insertionTime: Long = tag(AddFile.Tags.INSERTION_TIME).map(_.toLong)
// From modification time in milliseconds to microseconds.
@@ -448,6 +472,7 @@ case class AddFile(

val numLogicalRecords = if (node.has("numRecords")) {
Some(node.get("numRecords")).filterNot(_.isNull).map(_.asLong())
.map(_ - numDeletedRecords)
} else None

Some(ParsedStatsFields(
@@ -461,6 +486,13 @@
override lazy val numLogicalRecords: Option[Long] =
parsedStatsFields.flatMap(_.numLogicalRecords)

/** Returns the number of records marked as deleted. */
@JsonIgnore
def numDeletedRecords: Long = if (deletionVector != null) deletionVector.cardinality else 0L

/** Returns the total number of records, including those marked as deleted. */
@JsonIgnore
def numPhysicalRecords: Option[Long] = numLogicalRecords.map(_ + numDeletedRecords)
}

object AddFile {
@@ -539,6 +571,22 @@ case class RemoveFile(
@JsonIgnore
var numLogicalRecords: Option[Long] = None

/**
* Return the unique id of the deletion vector, if present, or `None` if there's no DV.
*
* The unique id differentiates DVs, even if there are multiple in the same file
* or the DV is stored inline.
*/
@JsonIgnore
def getDeletionVectorUniqueId: Option[String] = Option(deletionVector).map(_.uniqueId)

/** Returns the number of records marked as deleted. */
@JsonIgnore
def numDeletedRecords: Long = if (deletionVector != null) deletionVector.cardinality else 0L

/** Returns the total number of records, including those marked as deleted. */
@JsonIgnore
def numPhysicalRecords: Option[Long] = numLogicalRecords.map(_ + numDeletedRecords)

/**
* Create a copy with the new tag. `extendedFileMetadata` is copied unchanged.
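
The record-count fields added to AddFile and RemoveFile above follow a simple invariant: the stats keep the physical row count of the parquet file, the DV cardinality counts rows marked as deleted, and the logical count is their difference. A rough, self-contained sketch with simplified stand-in classes (not the real actions):

// Illustrative only; field names mirror the real actions but the types are simplified.
final case class MiniDV(cardinality: Long)

final case class MiniFileStats(
    numRecords: Option[Long],               // physical rows stored in the parquet file
    deletionVector: Option[MiniDV] = None) {

  def numDeletedRecords: Long = deletionVector.map(_.cardinality).getOrElse(0L)

  // Logical rows = physical rows minus the rows the DV marks as deleted; the stats
  // keep the physical count so they stay comparable across files with and without DVs.
  def numLogicalRecords: Option[Long] = numRecords.map(_ - numDeletedRecords)
  def numPhysicalRecords: Option[Long] = numRecords
}

object CountsDemo extends App {
  val noDV   = MiniFileStats(numRecords = Some(100L))
  val withDV = MiniFileStats(numRecords = Some(100L), deletionVector = Some(MiniDV(7L)))
  println(noDV.numLogicalRecords)   // Some(100)
  println(withDV.numLogicalRecords) // Some(93)
  println(withDV.numDeletedRecords) // 7
}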
@@ -176,6 +176,7 @@ private[delta] object DataSkippingReader {
val sizeCollectorInputEncoders: Seq[Option[ExpressionEncoder[_]]] = Seq(
Option(ExpressionEncoder[Boolean]()),
Option(ExpressionEncoder[java.lang.Long]()),
Option(ExpressionEncoder[java.lang.Long]()),
Option(ExpressionEncoder[java.lang.Long]()))
}

@@ -444,6 +445,8 @@ trait DataSkippingReaderBase
constructDataFilters(And(Not(e1), Not(e2)))

// Match any file whose null count is larger than zero.
// Note that DVs might result in a redundant read of a file.
// However, they cannot lead to a correctness issue.
case IsNull(SkippingEligibleColumn(a, _)) =>
statsProvider.getPredicateWithStatType(a, NULL_COUNT) { nullCount =>
nullCount > Literal(0L)
@@ -452,6 +455,7 @@
constructDataFilters(IsNotNull(e))

// Match any file whose null count is less than the row count.
// Note: when comparing numRecords to nullCount we should NOT take DV cardinality into account.
case IsNotNull(SkippingEligibleColumn(a, _)) =>
val nullCountCol = StatsColumn(NULL_COUNT, a)
val numRecordsCol = StatsColumn(NUM_RECORDS)
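
A hedged sketch of the two skipping predicates discussed in the comments above, assuming the usual stats column layout (per-column null counts under stats.nullCount.<name>, the physical row count under stats.numRecords): both compare physical counts and deliberately ignore DV cardinality, which can at worst make a read redundant, never incorrect.

object NullCountSkippingSketch {
  import org.apache.spark.sql.Column
  import org.apache.spark.sql.functions.{col, lit}

  // IsNull(a): keep the file only if at least one physical row has a null in `a`.
  def keepForIsNull(column: String): Column =
    col(s"stats.nullCount.$column") > lit(0L)

  // IsNotNull(a): keep the file only if not every physical row is null in `a`;
  // numRecords is the physical count, DV cardinality is intentionally not subtracted.
  def keepForIsNotNull(column: String): Column =
    col(s"stats.nullCount.$column") < col("stats.numRecords")
}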
@@ -677,6 +681,9 @@ trait DataSkippingReaderBase
// caller will negate the expression we return. In case a stats column is NULL, `NOT(expr)`
// must return `TRUE`, and without these NULL checks it would instead return
// `NOT(NULL)` => `NULL`.
// NOTE: Here we only verify the existence of statistics, so DVs do not cause any
// issue. Furthermore, the NUM_RECORDS === NULL_COUNT check below should NOT take
// DV cardinality into account.
referencedStats.flatMap { stat => stat match {
case StatsColumn(MIN, _) | StatsColumn(MAX, _) =>
Seq(stat, StatsColumn(NULL_COUNT, stat.pathToColumn), StatsColumn(NUM_RECORDS))
@@ -704,22 +711,25 @@ trait DataSkippingReaderBase
private def buildSizeCollectorFilter(): (ArrayAccumulator, Column => Column) = {
val bytesCompressed = col("size")
val rows = getStatsColumnOrNullLiteral(NUM_RECORDS)
val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0L))
val logicalRows = rows - dvCardinality as "logicalRows"

val accumulator = new ArrayAccumulator(
3
)
val accumulator = new ArrayAccumulator(4)

spark.sparkContext.register(accumulator)

// The arguments (order and datatype) must match the encoders defined in the
// `sizeCollectorInputEncoders` value.
val collector = (include: Boolean,
bytesCompressed: java.lang.Long,
logicalRows: java.lang.Long,
rows: java.lang.Long) => {
if (include) {
accumulator.add((0, bytesCompressed)) /* count bytes of AddFiles */
accumulator.add((1, Option(rows).map(_.toLong).getOrElse(-1L))) /* count rows in AddFiles */
accumulator.add((2, 1)) /* count number of AddFiles */
accumulator.add((3, Option(logicalRows)
.map(_.toLong).getOrElse(-1L))) /* count logical rows in AddFiles */
}
include
}
@@ -729,7 +739,7 @@ trait DataSkippingReaderBase
inputEncoders = sizeCollectorInputEncoders,
deterministic = false)

(accumulator, collectorUdf(_: Column, bytesCompressed, rows))
(accumulator, collectorUdf(_: Column, bytesCompressed, logicalRows, rows))
}
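
The logical row count fed to the collector above is derived as a column expression over the AddFile DataFrame; a small sketch of the idea, assuming the stats are exposed under stats.numRecords and the DV descriptor under deletionVector.cardinality (files without a DV coalesce to 0):

object LogicalRowsSketch {
  import org.apache.spark.sql.Column
  import org.apache.spark.sql.functions.{coalesce, col, lit}

  val dvCardinality: Column = coalesce(col("deletionVector.cardinality"), lit(0L)) // 0 when no DV
  val physicalRows: Column  = col("stats.numRecords")       // rows stored in the parquet file
  val logicalRows: Column   = physicalRows - dvCardinality  // rows a scan actually returns
}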

override def filesWithStatsForScan(partitionFilters: Seq[Expression]): DataFrame = {
@@ -841,6 +851,9 @@ trait DataSkippingReaderBase
/**
* Gathers files that should be included in a scan based on the given predicates.
* Statistics about the amount of data that will be read are gathered and returned.
* Note: the statistics column added when keepNumRecords = true should NOT take DVs
* into account, because consumers of this method might commit the file and the
* semantics of the statistics need to be consistent across all files.
*/
override def filesForScan(filters: Seq[Expression], keepNumRecords: Boolean): DeltaScan = {
val startTime = System.currentTimeMillis()
@@ -994,13 +1007,15 @@ trait DataSkippingReaderBase
val totalDataSize = new DataSize(
sizeInBytesIfKnown,
None,
numOfFilesIfKnown
numOfFilesIfKnown,
None
)

val scannedDataSize = new DataSize(
scan.byteSize,
scan.numPhysicalRecords,
Some(scan.files.size)
Some(scan.files.size),
scan.numLogicalRecords
)

DeltaScan(
@@ -1039,7 +1054,8 @@ trait DataSkippingReaderBase
"Delta", "DataSkippingReaderEdge.getFilesAndNumRecords") {
import org.apache.spark.sql.delta.implicits._

val numLogicalRecords = col("stats.numRecords")
val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0L))
val numLogicalRecords = col("stats.numRecords") - dvCardinality

val result = df.withColumn("numPhysicalRecords", col("stats.numRecords")) // Physical
.withColumn("numLogicalRecords", numLogicalRecords) // Logical
@@ -39,15 +39,18 @@ case class DataSize(
@JsonDeserialize(contentAs = classOf[java.lang.Long])
rows: Option[Long] = None,
@JsonDeserialize(contentAs = classOf[java.lang.Long])
files: Option[Long] = None
files: Option[Long] = None,
@JsonDeserialize(contentAs = classOf[java.lang.Long])
logicalRows: Option[Long] = None
)

object DataSize {
def apply(a: ArrayAccumulator): DataSize = {
DataSize(
Option(a.value(0)).filterNot(_ == -1),
Option(a.value(1)).filterNot(_ == -1),
Option(a.value(2)).filterNot(_ == -1)
Option(a.value(2)).filterNot(_ == -1),
Option(a.value(3)).filterNot(_ == -1)
)
}
}
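
An illustrative sketch of how the four accumulator slots map onto the DataSize fields; MiniDataSize and the slot values are invented for the example, and only the slot order (bytes, physical rows, file count, logical rows) and the convention that -1 means unknown come from the code above.

object DataSizeSlotsDemo extends App {
  final case class MiniDataSize(
      bytesCompressed: Option[Long],
      rows: Option[Long],          // physical rows
      files: Option[Long],
      logicalRows: Option[Long])   // physical rows minus DV cardinality

  def fromSlots(slots: Array[Long]): MiniDataSize = {
    def opt(i: Int): Option[Long] = Some(slots(i)).filterNot(_ == -1)
    MiniDataSize(opt(0), opt(1), opt(2), opt(3))
  }

  // Slot layout written by the collector UDF:
  // 0 = bytes of selected AddFiles, 1 = physical rows, 2 = file count, 3 = logical rows.
  println(fromSlots(Array(4096L, 100L, 1L, 93L))) // all four values known
  println(fromSlots(Array(4096L, -1L, 1L, -1L)))  // missing row counts stay None
}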