[Spark] Simplify reference binding in DeltaInvariantChecker
tomvanbussel committed Mar 3, 2025
1 parent fd6f7cd commit 02a80af
Showing 3 changed files with 44 additions and 35 deletions.
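
This commit replaces the hand-rolled reference binding in CheckDeltaInvariant (the withBoundReferences helper removed below, which called BindReferences.bindReference on the check expression and on every column extractor) with the binding that UnsafeProjection.create performs itself when handed the input attributes. A minimal sketch of the two approaches against Spark's catalyst API; the helper names here are illustrative, not from the commit:

    import org.apache.spark.sql.catalyst.expressions._

    // Before: bind each expression's AttributeReferences to row ordinals by hand,
    // then build the projection from the already-bound expressions.
    def projectOld(exprs: Seq[Expression], output: Seq[Attribute]): UnsafeProjection = {
      val bound = exprs.map(BindReferences.bindReference(_, output))
      UnsafeProjection.create(bound)
    }

    // After: the two-argument overload binds references to `output` internally.
    def projectNew(exprs: Seq[Expression], output: Seq[Attribute]): UnsafeProjection =
      UnsafeProjection.create(exprs, output)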
CheckDeltaInvariant.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.delta.constraints.Constraints.{Check, NotNull}
 import org.apache.spark.sql.delta.schema.DeltaInvariantViolationException
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, BindReferences, Expression, NonSQLExpression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types.{DataType, NullType}
@@ -34,23 +34,15 @@ import org.apache.spark.sql.types.{DataType, NullType}
  */
 case class CheckDeltaInvariant(
     child: Expression,
-    columnExtractors: Map[String, Expression],
+    columnExtractors: Seq[(String, Expression)],
     constraint: Constraint)
-  extends UnaryExpression with NonSQLExpression with CodegenFallback {
+  extends Expression with NonSQLExpression with CodegenFallback {
 
+  override def children: Seq[Expression] = child +: columnExtractors.map(_._2)
   override def dataType: DataType = NullType
   override def foldable: Boolean = false
   override def nullable: Boolean = true
 
-  def withBoundReferences(input: AttributeSeq): CheckDeltaInvariant = {
-    CheckDeltaInvariant(
-      BindReferences.bindReference(child, input),
-      columnExtractors.map {
-        case (column, extractor) => column -> BindReferences.bindReference(extractor, input)
-      },
-      constraint)
-  }
-
   private def assertRule(input: InternalRow): Unit = constraint match {
     case n: NotNull =>
       if (child.eval(input) == null) {
@@ -59,7 +51,12 @@
     case c: Check =>
       val result = child.eval(input)
       if (result == null || result == false) {
-        throw DeltaInvariantViolationException(c, columnExtractors.mapValues(_.eval(input)).toMap)
+        throw DeltaInvariantViolationException(
+          c,
+          columnExtractors.map {
+            case (column, extractor) => column -> extractor.eval(input)
+          }.toMap
+        )
       }
   }
 
@@ -111,8 +108,7 @@
     }.fold(start)(_ + _)
   }
 
-  private def generateExpressionValidationCode(
-      constraintName: String, expr: Expression, ctx: CodegenContext): Block = {
+  private def generateExpressionValidationCode(ctx: CodegenContext): Block = {
     val elementValue = child.genCode(ctx)
     val invariantField = ctx.addReferenceObj("errMsg", constraint)
     val colListName = ctx.freshName("colList")
@@ -129,12 +125,17 @@
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val code = constraint match {
-      case NotNull(_) => generateNotNullCode(ctx)
-      case Check(name, expr) => generateExpressionValidationCode(name, expr, ctx)
+      case _: NotNull => generateNotNullCode(ctx)
+      case _: Check => generateExpressionValidationCode(ctx)
     }
     ev.copy(code = code, isNull = TrueLiteral, value = JavaCode.literal("null", NullType))
   }
 
-  override protected def withNewChildInternal(newChild: Expression): CheckDeltaInvariant =
-    copy(child = newChild)
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression = {
+    copy(
+      child = newChildren.head,
+      columnExtractors = columnExtractors.map(_._1).zip(newChildren.tail)
+    )
+  }
 }
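
Two details above make that possible. Dropping UnaryExpression for Expression turns the column extractors into real children of CheckDeltaInvariant, so any transform that rewrites children (including the binding done inside UnsafeProjection.create) now reaches them. And columnExtractors becomes a Seq rather than a Map so its order is stable: withNewChildrenInternal re-pairs the extractor names with the rewritten children purely by position. A toy illustration of that re-pairing, with Literals standing in for transformed child expressions:

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

    // Children layout from above: children = child +: columnExtractors.map(_._2)
    val names = Seq("a", "b")  // extractor column names, in a fixed order
    val newChildren: IndexedSeq[Expression] = IndexedSeq(Literal(0), Literal(1), Literal(2))

    val newChild = newChildren.head                  // slot 0 replaces `child`
    val newExtractors = names.zip(newChildren.tail)  // Seq(("a", Literal(1)), ("b", Literal(2)))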
DeltaInvariantCheckerExec.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.StructType
  */
 case class DeltaInvariantChecker(
     child: LogicalPlan,
-    deltaConstraints: Seq[Constraint]) extends UnaryNode {
+    deltaConstraints: Seq[CheckDeltaInvariant]) extends UnaryNode {
   assert(deltaConstraints.nonEmpty)
 
   override def output: Seq[Attribute] = child.output
@@ -52,6 +52,17 @@
     copy(child = newChild)
 }
 
+object DeltaInvariantChecker {
+  def apply(
+      spark: SparkSession,
+      child: LogicalPlan,
+      constraints: Seq[Constraint]): DeltaInvariantChecker = {
+    val invariantChecks =
+      DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, spark)
+    DeltaInvariantChecker(child, invariantChecks)
+  }
+}
+
 object DeltaInvariantCheckerStrategy extends SparkStrategy {
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case DeltaInvariantChecker(child, constraints) =>
@@ -66,26 +77,15 @@
  */
 case class DeltaInvariantCheckerExec(
     child: SparkPlan,
-    constraints: Seq[Constraint]) extends UnaryExecNode {
+    constraints: Seq[CheckDeltaInvariant]) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
 
   override protected def doExecute(): RDD[InternalRow] = {
     if (constraints.isEmpty) return child.execute()
-    val invariantChecks =
-      DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, session)
-
-    // Resolve current_date()/current_time() expressions.
-    // We resolve currentTime for all invariants together to make sure we use the same timestamp.
-    val invariantsFakePlan = AnalysisHelper.FakeLogicalPlan(invariantChecks, Nil)
-    val newInvariantsPlan = optimizer.ComputeCurrentTime(invariantsFakePlan)
-    val localOutput = child.output
 
     child.execute().mapPartitionsInternal { rows =>
-      val boundRefs = newInvariantsPlan.expressions
-        .asInstanceOf[Seq[CheckDeltaInvariant]]
-        .map(_.withBoundReferences(localOutput))
-      val assertions = UnsafeProjection.create(boundRefs)
+      val assertions = UnsafeProjection.create(constraints, child.output)
       rows.map { row =>
         assertions(row)
         row
@@ -102,6 +102,14 @@
 }
 
 object DeltaInvariantCheckerExec extends DeltaLogging {
+  def apply(
+      spark: SparkSession,
+      child: SparkPlan,
+      constraints: Seq[Constraint]): DeltaInvariantCheckerExec = {
+    val invariantChecks =
+      DeltaInvariantCheckerExec.buildInvariantChecks(child.output, constraints, spark)
+    DeltaInvariantCheckerExec(child, invariantChecks)
+  }
 
   // Specialized optimizer to run necessary rules so that the check expressions can be evaluated.
   object DeltaInvariantCheckerOptimizer
@@ -201,7 +209,7 @@ object DeltaInvariantCheckerExec extends DeltaLogging {
         resolvedExpr
       }
 
-      CheckDeltaInvariant(executableExpr, columnExtractors.toMap, constraint)
+      CheckDeltaInvariant(executableExpr, columnExtractors.toSeq, constraint)
     }
   }
 }
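
With the logical node and the exec node now carrying pre-built CheckDeltaInvariant expressions, doExecute reduces to UnsafeProjection.create(constraints, child.output), which binds each expression's attribute references to row ordinals and code-generates the per-row evaluation. A self-contained sketch of that overload, using a plain comparison in place of a CheckDeltaInvariant (the column name and values are made up):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal, UnsafeProjection}
    import org.apache.spark.sql.types.LongType

    // A toy check, `id > 0`, evaluated against single-column rows.
    val id = AttributeReference("id", LongType, nullable = false)()
    val check = GreaterThan(id, Literal(0L))

    // create(exprs, inputSchema) binds `id` to its ordinal in the schema internally,
    // which is what lets the exec node skip any manual withBoundReferences step.
    val projection = UnsafeProjection.create(Seq(check), Seq(id))
    assert(projection(InternalRow(42L)).getBoolean(0))  // 42 > 0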
TransactionalWrite.scala
@@ -433,7 +433,7 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl
 
     val empty2NullPlan = convertEmptyToNullIfNeeded(queryExecution.executedPlan,
       partitioningColumns, constraints)
-    val checkInvariants = DeltaInvariantCheckerExec(empty2NullPlan, constraints)
+    val checkInvariants = DeltaInvariantCheckerExec(spark, empty2NullPlan, constraints)
     // No need to plan optimized write if the write command is OPTIMIZE, which aims to produce
     // evenly-balanced data files already.
     val physicalPlan = if (!isOptimize &&