Skip to content

Commit

Permalink
[SPARK-26965][SQL] Makes ElementAt nullability more precise for array…
Browse files Browse the repository at this point in the history
… cases

## What changes were proposed in this pull request?
In master, `ElementAt` is always reported as nullable (`nullable` is hard-coded to `true`):
https://github.com/apache/spark/blob/be1cadf16dc70e22eae144b3dfce9e269ef95acc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L1977

However, if the input is an array and the ordinal is foldable, its nullability can be made more precise.
This fix is based on SPARK-26637 (#23566).

## How was this patch tested?
Added tests in `CollectionExpressionsSuite`.

Closes #23867 from maropu/SPARK-26965.

Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
maropu authored and cloud-fan committed Mar 4, 2019
1 parent ad4823c commit 68fbbbe
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,8 @@ case class ArrayPosition(left: Expression, right: Expression)
b
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
case class ElementAt(left: Expression, right: Expression)
extends GetMapValueUtil with GetArrayItemUtil {

@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

Expand Down Expand Up @@ -1974,7 +1975,10 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}

override def nullable: Boolean = true
// Nullability depends on the type of the left input: array inputs delegate to
// computeNullabilityFromArray, map inputs are always reported as nullable.
// NOTE(review): the match has no default case, so any other left.dataType would raise a
// MatchError here -- presumably type checking restricts the input to arrays and maps; confirm.
// NOTE(review): computeNullabilityFromArray uses the ordinal directly as a zero-based index
// into the array's child expressions -- confirm that this agrees with the index convention
// (and any negative-index handling) of doElementAt.
override def nullable: Boolean = left.dataType match {
case _: ArrayType => computeNullabilityFromArray(left, right)
case _: MapType => true
}

override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ case class GetArrayStructFields(
* We need to do type checking here as `ordinal` expression maybe unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
extends BinaryExpression with GetArrayItemUtil with ExpectsInputTypes with ExtractValue
with NullIntolerant {

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
Expand All @@ -231,23 +232,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)

override def left: Expression = child
override def right: Expression = ordinal

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
case CreateArray(ar) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
case GetArrayStructFields(CreateArray(elements), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
true
}
} else {
true
}

override def nullable: Boolean = computeNullabilityFromArray(left, right)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType

protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
Expand Down Expand Up @@ -281,10 +266,34 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}

/**
* Common base class for [[GetMapValue]] and [[ElementAt]].
* Common trait for [[GetArrayItem]] and [[ElementAt]].
*/
trait GetArrayItemUtil {

  /**
   * Computes the nullability of an "extract one element from an array" expression.
   *
   * `Null` is returned for invalid ordinals, so the result is only reported as non-nullable
   * when the ordinal is a non-null constant that is a valid index into an array whose
   * element expressions are known (a `CreateArray`, possibly wrapped in
   * `GetArrayStructFields`). In every other case the result is conservatively nullable.
   *
   * @param child   the expression producing the array
   * @param ordinal the expression producing the index; it is used as a zero-based index into
   *                the element expressions of `child`
   * @return `false` only if the selected element (and, for `GetArrayStructFields`, the
   *         selected struct field) is known to be non-nullable; `true` otherwise
   */
  protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
    if (ordinal.foldable && !ordinal.nullable) {
      val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()

      // A negative constant ordinal must not be used to index the element list: indexing a
      // Seq with a negative value throws IndexOutOfBoundsException. Treat it like any other
      // invalid ordinal and fall through to the conservative `true` below.
      def isValidIndex(size: Int): Boolean = intOrdinal >= 0 && intOrdinal < size

      child match {
        case CreateArray(ar) if isValidIndex(ar.length) =>
          ar(intOrdinal).nullable
        case GetArrayStructFields(CreateArray(elements), field, _, _, _)
            if isValidIndex(elements.length) =>
          // The extracted field is null if either the struct itself or the field is null.
          elements(intOrdinal).nullable || field.nullable
        case _ =>
          true
      }
    } else {
      // The ordinal is not a non-null constant, so nothing can be said at analysis time.
      true
    }
  }
}

/**
* Common trait for [[GetMapValue]] and [[ElementAt]].
*/
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {

abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
Expand Down Expand Up @@ -380,23 +389,14 @@ case class GetMapValue(child: Expression, key: Expression)
override def left: Expression = child
override def right: Expression = key

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = if (key.foldable && !key.nullable) {
val keyObj = key.eval()
child match {
case m: CreateMap if m.resolved =>
m.keys.zip(m.values).filter { case (k, _) => k.foldable && !k.nullable }.find {
case (k, _) if k.eval() == keyObj => true
case _ => false
}.map(_._2.nullable).getOrElse(true)
case _ =>
true
}
} else {
true
}


/**
 * The result is always reported as nullable: `null` is returned for invalid keys.
 *
 * TODO: Nullability could be made more precise in foldable cases (e.g., literal input).
 * However, the key search is currently O(n), so computing nullability that way would be
 * expensive. Revisit this once an efficient key search is available.
 */
override def nullable: Boolean = true
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

// todo: current search is O(n), improve it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,39 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}

// Verifies the nullability that ElementAt reports for array inputs built from CreateArray,
// either directly or wrapped in GetArrayStructFields.
// NOTE(review): the constant ordinals below are used as zero-based positions into the
// CreateArray children -- confirm that this matches ElementAt's evaluation semantics.
test("correctly handles ElementAt nullability for arrays") {
// CreateArray case
// `a` is non-nullable and `b` is nullable, so the expected result follows the selected child.
val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = true)()
val array = CreateArray(a :: b :: Nil)
assert(!ElementAt(array, Literal(0)).nullable)
assert(ElementAt(array, Literal(1)).nullable)
// The ordinal is a constant expression rather than a plain literal.
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable)
// The ordinal is an attribute, so the result is expected to be nullable.
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable)

// GetArrayStructFields case
// Field `a` (f1) is non-nullable, field `b` (f2) is nullable.
val f1 = StructField("a", IntegerType, nullable = false)
val f2 = StructField("b", IntegerType, nullable = true)
val structType = StructType(f1 :: f2 :: Nil)
// Array with a single non-nullable struct: nullability is decided by the extracted field.
val c = AttributeReference("c", structType, nullable = false)()
val inputArray1 = CreateArray(c :: Nil)
val inputArray1ContainsNull = c.nullable
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull)
assert(!ElementAt(stArray1, Literal(0)).nullable)
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull)
assert(ElementAt(stArray2, Literal(0)).nullable)

// Array with a non-nullable struct (`c`) followed by a nullable struct (`d`): the result is
// nullable when either the selected struct or the extracted field is nullable.
val d = AttributeReference("d", structType, nullable = true)()
val inputArray2 = CreateArray(c :: d :: Nil)
val inputArray2ContainsNull = c.nullable || d.nullable
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull)
assert(!ElementAt(stArray3, Literal(0)).nullable)
assert(ElementAt(stArray3, Literal(1)).nullable)
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull)
assert(ElementAt(stArray4, Literal(0)).nullable)
assert(ElementAt(stArray4, Literal(1)).nullable)
}

test("Concat") {
// Primitive-type elements
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c"))
}

test("SPARK-26747 handles GetMapValue nullability correctly when input key is foldable") {
// String key test
val k1 = Literal("k1")
val v1 = AttributeReference("v1", StringType, nullable = true)()
val k2 = Literal("k2")
val v2 = AttributeReference("v2", StringType, nullable = false)()
val map1 = CreateMap(k1 :: v1 :: k2 :: v2 :: Nil)
assert(GetMapValue(map1, Literal("k1")).nullable)
assert(!GetMapValue(map1, Literal("k2")).nullable)
assert(GetMapValue(map1, Literal("non-existent-key")).nullable)

// Complex type key test
val k3 = Literal.create((1, "a"))
val k4 = Literal.create((2, "b"))
val map2 = CreateMap(k3 :: v1 :: k4 :: v2 :: Nil)
assert(GetMapValue(map2, Literal.create((1, "a"))).nullable)
assert(!GetMapValue(map2, Literal.create((2, "b"))).nullable)
}

test("GetStructField") {
val typeS = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(create_row(1), typeS)
Expand Down

0 comments on commit 68fbbbe

Please sign in to comment.