[SPARK-38030][SQL][3.1] Canonicalization should not remove nullability of AttributeReference dataType

This is a backport of #35332 to branch 3.1

### What changes were proposed in this pull request?
Canonicalization of AttributeReference should not remove nullability information of its dataType.

### Why are the changes needed?
SPARK-38030 describes an issue where canonicalization of a Cast produced an unresolved expression and caused a query failure. The root cause was that the child AttributeReference's dataType was converted to nullable during canonicalization, so the Cast's `checkInputDataTypes` check failed.
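
For illustration, a minimal sketch of the failure mode (mirroring the unit test added below; the names `structType`, `attr`, and `cast` are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A column whose struct type contains a non-nullable field.
val structType = StructType(Seq(StructField("name", StringType, nullable = false)))
val attr = AttributeReference("col", structType)()

// Casting the column to the same struct type is a valid, resolved cast.
val cast = Cast(attr, structType)
assert(cast.resolved)

// Before this fix, canonicalization replaced attr's dataType with
// `dataType.asNullable`; the canonicalized Cast then cast a nullable struct
// field to a non-nullable one, so `checkInputDataTypes` failed and the
// expression became unresolved: `cast.canonicalized.resolved` was false.
```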

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a unit test to ensure that canonicalization preserves the nullability of AttributeReference and does not turn a resolved Cast into an unresolved one. Also added a test for the issue observed in SPARK-38030 (the interaction of this bug with AQE). The repro only works on 3.1 because the code that accesses the unresolved expression is [lazy](https://github.com/apache/spark/blob/7e5c3b216431b6a5e9a0786bf7cded694228cdee/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala#L132) in 3.2+ and hence does not trigger the issue there.
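
For reference, the new tests can be run locally with commands like `build/sbt "catalyst/testOnly *CanonicalizeSuite"` and `build/sbt "sql/testOnly *AdaptiveQueryExecSuite"` (assuming a standard Spark development setup).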

Closes #35444 from shardulm94/SPARK-38030-3.1.

Authored-by: Shardul Mahadik <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
shardulm94 authored and dongjoon-hyun committed Feb 9, 2022
1 parent 1aa9ef0 commit 66e73c4
Showing 3 changed files with 39 additions and 7 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala

@@ -25,8 +25,7 @@ package org.apache.spark.sql.catalyst.expressions
  * return the same answer given any input (i.e. false negatives are possible).
  *
  * The following rules are applied:
- * - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
- * - Names for [[GetStructField]] are stripped.
+ * - Names for [[org.apache.spark.sql.types.DataType]]s and [[GetStructField]] are stripped.
  * - TimeZoneId for [[Cast]] and [[AnsiCast]] are stripped if `needsTimeZone` is false.
  * - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
  *   by `hashCode`.
@@ -39,10 +38,10 @@ object Canonicalize {
     expressionReorder(ignoreTimeZone(ignoreNamesTypes(e)))
   }
 
-  /** Remove names and nullability from types, and names from `GetStructField`. */
+  /** Remove names from types and `GetStructField`. */
   private[expressions] def ignoreNamesTypes(e: Expression): Expression = e match {
     case a: AttributeReference =>
-      AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+      AttributeReference("none", a.dataType)(exprId = a.exprId)
     case GetStructField(child, ordinal, Some(_)) => GetStructField(child, ordinal, None)
     case _ => e
   }
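
As background for the change above, a short sketch of how `asNullable` loosens a struct type (the object name is illustrative, and `asNullable` is `private[spark]`, so this only compiles inside a Spark package):

```scala
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Illustrative demo object; not part of the patch.
object AsNullableDemo {
  val tight = StructType(Seq(StructField("name", StringType, nullable = false)))

  // The old canonicalization used `a.dataType.asNullable`, which recursively
  // marks every field nullable:
  val loosened = tight.asNullable
  // loosened == StructType(Seq(StructField("name", StringType, nullable = true)))

  // For struct-to-struct casts, a nullable source field cannot be cast to a
  // non-nullable target field, so a Cast from `loosened` to `tight` fails
  // `checkInputDataTypes` and is left unresolved.
}
```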
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala

@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Range
-import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
 
 class CanonicalizeSuite extends SparkFunSuite {
 
@@ -170,4 +170,17 @@ class CanonicalizeSuite extends SparkFunSuite {
       assert(nestedExpr2.canonicalized != nestedExpr3.canonicalized)
     }
   }
+
+  test("SPARK-38030: Canonicalization should not remove nullability of AttributeReference" +
+    " dataType") {
+    val structType = StructType(Seq(StructField("name", StringType, nullable = false)))
+    val attr = AttributeReference("col", structType)()
+    // AttributeReference dataType should not be converted to nullable
+    assert(attr.canonicalized.dataType === structType)
+
+    val cast = Cast(attr, structType)
+    assert(cast.resolved)
+    // canonicalization should not convert a resolved cast into an unresolved one
+    assert(cast.canonicalized.resolved)
+  }
 }
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

@@ -25,7 +25,7 @@ import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, ConvertToLocalRelation, PropagateEmptyRelation}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
@@ -39,7 +39,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.Utils
 
@@ -1502,4 +1502,24 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-38030: Query with cast containing non-nullable columns should succeed with AQE") {
+    import scala.collection.JavaConverters._
+    withSQLConf(
+      // disable some optimizer rules which prevent repro with an empty DataFrame
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+        (ConvertToLocalRelation.ruleName + "," + PropagateEmptyRelation.ruleName),
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val nameType = StructType(Seq(StructField("firstName", StringType, nullable = false)))
+      val schema = StructType(Seq(StructField("name", nameType, nullable = false)))
+      // only change the column name so that it is a valid cast
+      val newNameType = StructType(Seq(StructField("fname", StringType, nullable = false)))
+
+      val df = spark.createDataFrame(List.empty[Row].asJava, schema)
+      val df1 = df.withColumn("newName", 'name.cast(newNameType))
+      // required to trigger the issue observed in SPARK-38030
+      val df2 = df1.union(df1).repartition(1)
+      assert(df2.collect().length == 0)
+    }
+  }
 }
