Skip to content

Commit

Permalink
[Spark] Simplify reference binding in DeltaInvariantChecker (#4209)
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!  Here are some tips for you:
1. If this is your first time, please read our contributor guidelines:
https://github.com/delta-io/delta/blob/master/CONTRIBUTING.md
2. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP]
Your PR title ...'.
  3. Be sure to keep the PR description updated to reflect all changes.
  4. Please write your PR title to summarize what this PR proposes.
5. If possible, provide a concise example to reproduce the issue for a
faster review.
6. If applicable, include the corresponding issue number in the PR title
and link it in the body.
-->

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR simplifies the binding of attribute references in
`DeltaInvariantCheckerExec`. Before this PR we had some code to manually
bind these references. This PR changes this to make Spark do the binding
for us.

## How was this patch tested?

Existing tests.

## Does this PR introduce _any_ user-facing changes?

<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
  • Loading branch information
tomvanbussel authored Mar 3, 2025
1 parent 6dec3d7 commit 08f3a79
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.sql.delta.constraints.Constraints.{Check, NotNull}
import org.apache.spark.sql.delta.schema.DeltaInvariantViolationException

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, BindReferences, Expression, NonSQLExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, NullType}
Expand All @@ -34,23 +34,15 @@ import org.apache.spark.sql.types.{DataType, NullType}
*/
case class CheckDeltaInvariant(
child: Expression,
columnExtractors: Map[String, Expression],
columnExtractors: Seq[(String, Expression)],
constraint: Constraint)
extends UnaryExpression with NonSQLExpression with CodegenFallback {
extends Expression with NonSQLExpression with CodegenFallback {

override def children: Seq[Expression] = child +: columnExtractors.map(_._2)
override def dataType: DataType = NullType
override def foldable: Boolean = false
override def nullable: Boolean = true

def withBoundReferences(input: AttributeSeq): CheckDeltaInvariant = {
CheckDeltaInvariant(
BindReferences.bindReference(child, input),
columnExtractors.map {
case (column, extractor) => column -> BindReferences.bindReference(extractor, input)
},
constraint)
}

private def assertRule(input: InternalRow): Unit = constraint match {
case n: NotNull =>
if (child.eval(input) == null) {
Expand All @@ -59,7 +51,12 @@ case class CheckDeltaInvariant(
case c: Check =>
val result = child.eval(input)
if (result == null || result == false) {
throw DeltaInvariantViolationException(c, columnExtractors.mapValues(_.eval(input)).toMap)
throw DeltaInvariantViolationException(
c,
columnExtractors.map {
case (column, extractor) => column -> extractor.eval(input)
}.toMap
)
}
}

Expand Down Expand Up @@ -111,8 +108,7 @@ case class CheckDeltaInvariant(
}.fold(start)(_ + _)
}

private def generateExpressionValidationCode(
constraintName: String, expr: Expression, ctx: CodegenContext): Block = {
private def generateExpressionValidationCode(ctx: CodegenContext): Block = {
val elementValue = child.genCode(ctx)
val invariantField = ctx.addReferenceObj("errMsg", constraint)
val colListName = ctx.freshName("colList")
Expand All @@ -129,12 +125,17 @@ case class CheckDeltaInvariant(

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val code = constraint match {
case NotNull(_) => generateNotNullCode(ctx)
case Check(name, expr) => generateExpressionValidationCode(name, expr, ctx)
case _: NotNull => generateNotNullCode(ctx)
case _: Check => generateExpressionValidationCode(ctx)
}
ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
}

override protected def withNewChildInternal(newChild: Expression): CheckDeltaInvariant =
copy(child = newChild)
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
copy(
child = newChildren.head,
columnExtractors = columnExtractors.map(_._1).zip(newChildren.tail)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.StructType
*/
case class DeltaInvariantChecker(
child: LogicalPlan,
deltaConstraints: Seq[Constraint]) extends UnaryNode {
deltaConstraints: Seq[CheckDeltaInvariant]) extends UnaryNode {
assert(deltaConstraints.nonEmpty)

override def output: Seq[Attribute] = child.output
Expand All @@ -52,6 +52,17 @@ case class DeltaInvariantChecker(
copy(child = newChild)
}

object DeltaInvariantChecker {
def apply(
spark: SparkSession,
child: LogicalPlan,
constraints: Seq[Constraint]): DeltaInvariantChecker = {
val invariantChecks =
DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, spark)
DeltaInvariantChecker(child, invariantChecks)
}
}

object DeltaInvariantCheckerStrategy extends SparkStrategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case DeltaInvariantChecker(child, constraints) =>
Expand All @@ -66,26 +77,21 @@ object DeltaInvariantCheckerStrategy extends SparkStrategy {
*/
case class DeltaInvariantCheckerExec(
child: SparkPlan,
constraints: Seq[Constraint]) extends UnaryExecNode {
constraints: Seq[CheckDeltaInvariant]) extends UnaryExecNode {

override def output: Seq[Attribute] = child.output

override protected def doExecute(): RDD[InternalRow] = {
if (constraints.isEmpty) return child.execute()
val invariantChecks =
DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, session)

// Resolve current_date()/current_time() expressions.
// We resolve currentTime for all invariants together to make sure we use the same timestamp.
val invariantsFakePlan = AnalysisHelper.FakeLogicalPlan(invariantChecks, Nil)
val invariantsFakePlan = AnalysisHelper.FakeLogicalPlan(constraints, Nil)
val newInvariantsPlan = optimizer.ComputeCurrentTime(invariantsFakePlan)
val localOutput = child.output
val constraintsWithFixedTime = newInvariantsPlan.expressions.toArray

child.execute().mapPartitionsInternal { rows =>
val boundRefs = newInvariantsPlan.expressions
.asInstanceOf[Seq[CheckDeltaInvariant]]
.map(_.withBoundReferences(localOutput))
val assertions = UnsafeProjection.create(boundRefs)
val assertions = UnsafeProjection.create(constraintsWithFixedTime, child.output)
rows.map { row =>
assertions(row)
row
Expand All @@ -102,6 +108,14 @@ case class DeltaInvariantCheckerExec(
}

object DeltaInvariantCheckerExec extends DeltaLogging {
def apply(
spark: SparkSession,
child: SparkPlan,
constraints: Seq[Constraint]): DeltaInvariantCheckerExec = {
val invariantChecks =
DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, spark)
DeltaInvariantCheckerExec(child, invariantChecks)
}

// Specialized optimizer to run necessary rules so that the check expressions can be evaluated.
object DeltaInvariantCheckerOptimizer
Expand Down Expand Up @@ -201,7 +215,7 @@ object DeltaInvariantCheckerExec extends DeltaLogging {
resolvedExpr
}

CheckDeltaInvariant(executableExpr, columnExtractors.toMap, constraint)
CheckDeltaInvariant(executableExpr, columnExtractors.toSeq, constraint)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl

val empty2NullPlan = convertEmptyToNullIfNeeded(queryExecution.executedPlan,
partitioningColumns, constraints)
val checkInvariants = DeltaInvariantCheckerExec(empty2NullPlan, constraints)
val checkInvariants = DeltaInvariantCheckerExec(spark, empty2NullPlan, constraints)
// No need to plan optimized write if the write command is OPTIMIZE, which aims to produce
// evenly-balanced data files already.
val physicalPlan = if (!isOptimize &&
Expand Down

0 comments on commit 08f3a79

Please sign in to comment.