From ded67f55c88425e27ee5e73b5f62e6b773d610e9 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 17 May 2018 13:45:19 +0200 Subject: [PATCH 01/10] [SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in new collection expressions. --- .../expressions/collectionOperations.scala | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2a4e42d4ba316..f08baeda52794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -388,7 +388,8 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient + private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -552,7 +553,8 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def children: Seq[Expression] = Seq(x, start, length) - lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + @transient + private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -837,8 +839,6 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) - override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -870,6 +870,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override protected def nullSafeEval(input: Any): Any = { var min: Any = null + val ordering = TypeUtils.getInterpretedOrdering(dataType) input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => if (item != null && (min == null || ordering.lt(item, min))) { min = item @@ -902,8 +903,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) - override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -935,6 +934,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override protected def nullSafeEval(input: Any): Any = { var max: Any = null + val ordering = TypeUtils.getInterpretedOrdering(dataType) input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => if (item != null && (max == null || ordering.gt(item, max))) { max = item @@ -1126,9 +1126,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti """) case class Concat(children: Seq[Expression]) extends Expression { - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + private def maxArrayLength: Int = 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -1147,6 +1147,7 @@ case class Concat(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + @transient lazy val javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -1167,9 +1168,9 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > maxArrayLength) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s" elements due to exceeding the array size limit $maxArrayLength.") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -1227,9 +1228,9 @@ case class Concat(children: Seq[Expression]) extends Expression { |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > $maxArrayLength) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit $maxArrayLength."); |} """.stripMargin @@ -1324,15 +1325,17 @@ case class Concat(children: Seq[Expression]) extends Expression { since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + private def maxArrayLength: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + @transient private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull override def dataType: DataType = childDataType.elementType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient + private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -1352,9 +1355,9 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > maxArrayLength) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit $maxArrayLength.") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -1401,9 +1404,9 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $MAX_ARRAY_LENGTH) { + |if ($variableName > $maxArrayLength) { | throw 
new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | $variableName + " elements due to exceeding the array size limit $maxArrayLength."); |} """.stripMargin (code, variableName) @@ -1483,7 +1486,7 @@ case class Flatten(child: Expression) extends UnaryExpression { case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + private def maxArrayLength: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) @@ -1496,9 +1499,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + if (count.asInstanceOf[Int] > maxArrayLength) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + s"due to exceeding the array size limit $maxArrayLength."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -1557,9 +1560,9 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > $maxArrayLength) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit $maxArrayLength."); |} """.stripMargin From e96962e831a35f61e113e27c38e573a9b4d3ddc2 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 17 May 2018 14:56:33 +0200 Subject: [PATCH 02/10] [SPARK-24305][SQL][FOLLOWUP] Reverting change of making the ordering variable local. 
--- .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index f08baeda52794..5a7d7689f7453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -839,6 +839,9 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + @transient + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -870,7 +873,6 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override protected def nullSafeEval(input: Any): Any = { var min: Any = null - val ordering = TypeUtils.getInterpretedOrdering(dataType) input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => if (item != null && (min == null || ordering.lt(item, min))) { min = item @@ -903,6 +905,9 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + @transient + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { @@ -934,7 +939,6 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override protected def nullSafeEval(input: Any): Any = { var max: Any = null - val ordering = TypeUtils.getInterpretedOrdering(dataType) input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => if (item != null && (max == null || ordering.gt(item, max))) { max = item From f6368b5c68b69f9595240bb98cd533f9d7a110e6 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 17 May 2018 17:12:30 +0200 Subject: [PATCH 03/10] [SPARK-24305][SQL][FOLLOWUP] Using ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH directly --- .../expressions/collectionOperations.scala | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5a7d7689f7453..46b185cb49c1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1130,8 +1130,6 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti """) case class Concat(children: Seq[Expression]) extends Expression { - private def maxArrayLength: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { @@ -1172,9 +1170,10 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + 
ad.numElements()) - if (numberOfElements > maxArrayLength) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $maxArrayLength.") + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -1232,9 +1231,10 @@ case class Concat(children: Seq[Expression]) extends Expression { |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); |} - |if ($numElements > $maxArrayLength) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $maxArrayLength."); + | " elements due to exceeding the array size limit " + + | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -1329,8 +1329,6 @@ case class Concat(children: Seq[Expression]) extends Expression { since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private def maxArrayLength: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - @transient private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -1359,9 +1357,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > maxArrayLength) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $maxArrayLength.") + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -1408,9 +1407,10 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $maxArrayLength) { + |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $maxArrayLength."); + | $variableName + " elements due to exceeding the array size limit " + + | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin (code, variableName) @@ -1490,8 +1490,6 @@ case class Flatten(child: Expression) extends UnaryExpression { case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private def maxArrayLength: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) @@ -1503,9 +1501,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > maxArrayLength) { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due 
to exceeding the array size limit $maxArrayLength."); + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -1564,9 +1562,10 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $maxArrayLength) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $maxArrayLength."); + | " elements due to exceeding the array size limit " + + | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin From 2862d3e4ad7c2207f23db2f2d58fb27ba6e708c5 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Thu, 17 May 2018 21:04:19 +0200 Subject: [PATCH 04/10] [SPARK-24305][SQL][FOLLOWUP] Fixing failing tests. --- .../catalyst/expressions/collectionOperations.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 46b185cb49c1d..652e990c30f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1233,8 +1233,8 @@ case class Concat(children: Seq[Expression]) extends Expression { |} |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit " + - | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -1409,8 +1409,8 @@ case class Flatten(child: Expression) extends UnaryExpression { |} |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit " + - | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | $variableName + " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin (code, variableName) @@ -1564,8 +1564,8 @@ case class ArrayRepeat(left: Expression, right: Expression) |} |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit " + - | ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin From 62c55ada0e23eb47eb9d3b717f9a9fbc8155a05f Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 16 Jul 2018 17:13:51 +0200 Subject: [PATCH 05/10] [SPARK-24305][SQL][FOLLOWUP] Using def wherever it's possible and covering more expressions. 
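
For context, a minimal standalone sketch (hypothetical names, not part of this patch) of the trade-off applied throughout this commit: a private def adds nothing to the serialized expression and is simply re-evaluated on each call, which is fine for cheap derivations, while fields read repeatedly from per-row eval paths stay @transient lazy val so they are computed once per deserialized instance yet still excluded from serialization.

    case class PayloadSummary(payload: Array[Byte]) {
      // Cheap to recompute, so a def keeps it out of the serialized form entirely.
      private def sizeHint: Int = payload.length * 2

      // Potentially expensive: cached once per instance, skipped by serialization,
      // and rebuilt lazily after deserialization.
      @transient private lazy val hexDump: String =
        payload.map(b => f"${b & 0xff}%02x").mkString(" ")

      def describe(): String = s"needs about $sizeHint bytes; payload = $hexDump"
    }

    // Example: PayloadSummary(Array[Byte](1, 15, 127)).describe()
    //   => "needs about 6 bytes; payload = 01 0f 7f"
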
--- .../expressions/collectionOperations.scala | 74 ++++++++----------- 1 file changed, 30 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 94a8fa43f564a..4368f3422f835 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -168,27 +168,21 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) - override def dataType: DataType = ArrayType(mountSchema) - - override def nullable: Boolean = children.exists(_.nullable) - - private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - - private lazy val arrayElementTypes = arrayTypes.map(_.elementType) - - @transient private lazy val mountSchema: StructType = { + @transient override lazy val dataType: DataType = { val fields = children.zip(arrayElementTypes).zipWithIndex.map { case ((expr: NamedExpression, elementType), _) => StructField(expr.name, elementType, nullable = true) case ((_, elementType), idx) => StructField(idx.toString, elementType, nullable = true) } - StructType(fields) + ArrayType(StructType(fields), containsNull = false) } - @transient lazy val numberOfArrays: Int = children.length + override def nullable: Boolean = children.exists(_.nullable) + + private def arrayElementTypes = children.map(_.dataType.asInstanceOf[ArrayType].elementType) - @transient lazy val genericArrayData = classOf[GenericArrayData].getName + private def genericArrayData = classOf[GenericArrayData].getName def emptyInputGenCode(ev: ExprCode): ExprCode = { ev.copy(code""" @@ -256,7 +250,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI ("ArrayData[]", arrVals) :: Nil) val initVariables = s""" - |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |ArrayData[] $arrVals = new ArrayData[${children.length}]; |int $biggestCardinality = 0; |${CodeGenerator.javaType(dataType)} ${ev.value} = null; """.stripMargin @@ -268,7 +262,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |if (!${ev.isNull}) { | Object[] $args = new Object[$biggestCardinality]; | for (int $i = 0; $i < $biggestCardinality; $i ++) { - | Object[] $currentRow = new Object[$numberOfArrays]; + | Object[] $currentRow = new Object[${children.length}]; | $getValueForTypeSplitted | $args[$i] = new $genericInternalRow($currentRow); | } @@ -278,7 +272,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (numberOfArrays == 0) { + if (children.length == 0) { emptyInputGenCode(ev) } else { nonEmptyInputGenCode(ctx, ev) @@ -360,7 +354,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + private def childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -741,14 +735,15 @@ case class MapConcat(children: Seq[Expression]) extends Expression { since = "2.4.0") case class MapFromEntries(child: Expression) extends 
UnaryExpression { - @transient - private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { - case ArrayType( + @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = { + child.dataType match { + case ArrayType( StructType(Array( - StructField(_, keyType, keyNullable, _), - StructField(_, valueType, valueNullable, _))), + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) - case _ => None + case _ => None + } } private def nullEntries: Boolean = dataTypeDetails.get._3 @@ -953,8 +948,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient - private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -976,8 +970,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient - private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -1215,8 +1208,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - @transient - private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1608,8 +1600,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def children: Seq[Expression] = Seq(x, start, length) - @transient - private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1895,8 +1886,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - @transient - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1961,8 +1951,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - @transient - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -2236,8 +2225,7 @@ case class Concat(children: Seq[Expression]) extends Expression { override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - @transient - lazy val javaType: String = 
CodeGenerator.javaType(dataType) + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2416,15 +2404,13 @@ case class Concat(children: Seq[Expression]) extends Expression { since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - @transient - private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull override def dataType: DataType = childDataType.elementType - @transient - private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -2607,7 +2593,7 @@ case class Sequence( override def nullable: Boolean = children.exists(_.nullable) - override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) override def checkInputDataTypes(): TypeCheckResult = { val startType = start.dataType @@ -2638,7 +2624,7 @@ case class Sequence( stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), timeZoneId) - private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: SequenceImpl = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) @@ -3115,7 +3101,7 @@ case class ArrayRemove(left: Expression, right: Expression) Seq(ArrayType, elementType) } - lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -3232,7 +3218,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) From 294ac69e618bb8d8b2f988540338d27534b560e9 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 16 Jul 2018 17:39:51 +0200 Subject: [PATCH 06/10] [SPARK-24305][SQL][FOLLOWUP] Fixing indentation --- .../sql/catalyst/expressions/collectionOperations.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4368f3422f835..ccae512327cb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -738,10 +738,10 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = { child.dataType match { case ArrayType( - StructType(Array( - StructField(_, keyType, 
keyNullable, _), - StructField(_, valueType, valueNullable, _))), - containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + StructType(Array( + StructField(_, kt, kn, _), + StructField(_, vt, vn, _))), + cn) => Some((MapType(kt, vt, vn), kn, cn)) case _ => None } } From 94b86a2845332b249380319d156834f96488d2fa Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 11:41:17 +0200 Subject: [PATCH 07/10] [SPARK-24305][SQL][FOLLOWUP] Making fields used in eval lazy-evaluated --- .../expressions/collectionOperations.scala | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ccae512327cb2..f19904470c679 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -180,7 +180,9 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def nullable: Boolean = children.exists(_.nullable) - private def arrayElementTypes = children.map(_.dataType.asInstanceOf[ArrayType].elementType) + @transient private lazy val arrayElementTypes = { + children.map(_.dataType.asInstanceOf[ArrayType].elementType) + } private def genericArrayData = classOf[GenericArrayData].getName @@ -746,11 +748,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { } } - private def nullEntries: Boolean = dataTypeDetails.get._3 + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 override def nullable: Boolean = child.nullable || nullEntries - override def dataType: MapType = dataTypeDetails.get._1 + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess @@ -992,7 +994,10 @@ trait ArraySortLike extends ExpectsInputTypes { } } - def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = { + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + } + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { @@ -1208,7 +1213,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1600,7 +1605,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def children: Seq[Expression] = Seq(x, start, length) - private def elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1927,7 +1932,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with 
ImplicitCast min } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1992,7 +1997,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast max } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -2097,7 +2102,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) - override def dataType: DataType = left.dataType match { + @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType } @@ -2408,9 +2413,9 @@ case class Flatten(child: Expression) extends UnaryExpression { override def nullable: Boolean = child.nullable || childDataType.containsNull - override def dataType: DataType = childDataType.elementType + @transient override lazy val dataType: DataType = childDataType.elementType - private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -3218,7 +3223,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) From 872ef99987de0ed25c2e78a5347a2b18be1adc4b Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 11:50:18 +0200 Subject: [PATCH 08/10] [SPARK-24305][SQL][FOLLOWUP] Making fields used in eval lazy-evaluated (Concat, MapConcat) --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2345037ee7376..2b61261fb3e73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -516,7 +516,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } } - override def dataType: MapType = { + @transient override lazy val dataType: MapType = { if (children.isEmpty) { MapType(StringType, StringType) } else { @@ -2224,7 +2224,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - override def dataType: DataType = { + @transient override lazy val dataType: DataType = { if (children.isEmpty) { StringType } else { From fd3a94524b3916507381fd951d65d0a90afad38e Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 15:06:14 +0200 Subject: [PATCH 09/10] 
[SPARK-24305][SQL][FOLLOWUP] Addressing review comments --- .../expressions/collectionOperations.scala | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2b61261fb3e73..66f7730932d0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -180,9 +180,8 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def nullable: Boolean = children.exists(_.nullable) - @transient private lazy val arrayElementTypes = { + @transient private lazy val arrayElementTypes = children.map(_.dataType.asInstanceOf[ArrayType].elementType) - } private def genericArrayData = classOf[GenericArrayData].getName @@ -356,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - private def childDataType: MapType = child.dataType.asInstanceOf[MapType] + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -733,15 +732,14 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres since = "2.4.0") case class MapFromEntries(child: Expression) extends UnaryExpression { - @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = { - child.dataType match { - case ArrayType( - StructType(Array( - StructField(_, kt, kn, _), - StructField(_, vt, vn, _))), - cn) => Some((MapType(kt, vt, vn), kn, cn)) - case _ => None - } + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None } @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 @@ -990,9 +988,8 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient lazy val elementType: DataType = { + @transient lazy val elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType - } def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull From 922d2f081f034295a55ab9a0c454723b178057fc Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 15:35:51 +0200 Subject: [PATCH 10/10] [SPARK-24305][SQL][FOLLOWUP] A small optimization of Slice and ElementAt expression --- .../expressions/collectionOperations.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 66f7730932d0b..d60f4c36fa214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1596,7 +1596,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def 
inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - override def children: Seq[Expression] = Seq(x, start, length) + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType @@ -2092,8 +2092,11 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType @@ -2104,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti Seq(TypeCollection(ArrayType, MapType), left.dataType match { case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _: MapType => mapKeyType case _ => AnyDataType // no match for a wrong 'left' expression type } ) @@ -2114,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr( - left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } @@ -2137,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + if (arrayContainsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + getValueEval(value, ordinal, mapKeyType, ordering) } } @@ -2153,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val nullCheck = if (arrayContainsNull) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true;