Skip to content

Commit

Permalink
[SPARK-24305][SQL][FOLLOWUP] A small optimization of Slice and Elemen…
Browse files Browse the repository at this point in the history
…tAt expression
  • Loading branch information
mn-mikke committed Jul 17, 2018
1 parent fd3a945 commit 922d2f0
Showing 1 changed file with 11 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1596,7 +1596,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)

override def children: Seq[Expression] = Seq(x, start, length)
@transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval

@transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

Expand Down Expand Up @@ -2092,8 +2092,11 @@ case class ArrayPosition(left: Expression, right: Expression)
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

@transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull

@transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)

@transient override lazy val dataType: DataType = left.dataType match {
case ArrayType(elementType, _) => elementType
Expand All @@ -2104,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
Seq(TypeCollection(ArrayType, MapType),
left.dataType match {
case _: ArrayType => IntegerType
case _: MapType => left.dataType.asInstanceOf[MapType].keyType
case _: MapType => mapKeyType
case _ => AnyDataType // no match for a wrong 'left' expression type
}
)
Expand All @@ -2114,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
TypeUtils.checkForOrderingExpr(
left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName")
case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
}
}
Expand All @@ -2137,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
} else {
array.numElements() + index
}
if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
if (arrayContainsNull && array.isNullAt(idx)) {
null
} else {
array.get(idx, dataType)
}
}
case _: MapType =>
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
getValueEval(value, ordinal, mapKeyType, ordering)
}
}

Expand All @@ -2153,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
case _: ArrayType =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val index = ctx.freshName("elementAtIndex")
val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
val nullCheck = if (arrayContainsNull) {
s"""
|if ($eval1.isNullAt($index)) {
| ${ev.isNull} = true;
Expand Down

0 comments on commit 922d2f0

Please sign in to comment.