diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 972bc6e57892c..d60f4c36fa214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -168,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) - override def dataType: DataType = ArrayType(mountSchema) - - override def nullable: Boolean = children.exists(_.nullable) - - private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - - private lazy val arrayElementTypes = arrayTypes.map(_.elementType) - - @transient private lazy val mountSchema: StructType = { + @transient override lazy val dataType: DataType = { val fields = children.zip(arrayElementTypes).zipWithIndex.map { case ((expr: NamedExpression, elementType), _) => StructField(expr.name, elementType, nullable = true) case ((_, elementType), idx) => StructField(idx.toString, elementType, nullable = true) } - StructType(fields) + ArrayType(StructType(fields), containsNull = false) } - @transient lazy val numberOfArrays: Int = children.length + override def nullable: Boolean = children.exists(_.nullable) + + @transient private lazy val arrayElementTypes = + children.map(_.dataType.asInstanceOf[ArrayType].elementType) - @transient lazy val genericArrayData = classOf[GenericArrayData].getName + private def genericArrayData = classOf[GenericArrayData].getName def emptyInputGenCode(ev: ExprCode): ExprCode = { ev.copy(code""" @@ -256,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI ("ArrayData[]", arrVals) :: Nil) val initVariables = s""" - |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |ArrayData[] $arrVals = new ArrayData[${children.length}]; |int $biggestCardinality = 0; |${CodeGenerator.javaType(dataType)} ${ev.value} = null; """.stripMargin @@ -268,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |if (!${ev.isNull}) { | Object[] $args = new Object[$biggestCardinality]; | for (int $i = 0; $i < $biggestCardinality; $i ++) { - | Object[] $currentRow = new Object[$numberOfArrays]; + | Object[] $currentRow = new Object[${children.length}]; | $getValueForTypeSplitted | $args[$i] = new $genericInternalRow($currentRow); | } @@ -278,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (numberOfArrays == 0) { + if (children.length == 0) { emptyInputGenCode(ev) } else { nonEmptyInputGenCode(ctx, ev) @@ -360,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -520,7 +515,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } } - override def dataType: MapType = { + @transient override lazy val dataType: MapType = { if (children.isEmpty) { MapType(StringType, StringType) } else { @@ -747,11 +742,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { case _ => None } - private def nullEntries: Boolean = dataTypeDetails.get._3 + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 override def nullable: Boolean = child.nullable || nullEntries - override def dataType: MapType = dataTypeDetails.get._1 + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess @@ -949,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient - private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -972,8 +966,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient - private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -995,7 +988,9 @@ trait ArraySortLike extends ExpectsInputTypes { } } - def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { @@ -1211,7 +1206,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1601,9 +1596,9 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - override def children: Seq[Expression] = Seq(x, start, length) + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval - lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1889,7 +1884,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1930,7 +1925,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast min } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1954,7 +1949,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1995,7 +1990,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast max } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -2097,10 +2092,13 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) - override def dataType: DataType = left.dataType match { + @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType } @@ -2109,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti Seq(TypeCollection(ArrayType, MapType), left.dataType match { case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _: MapType => mapKeyType case _ => AnyDataType // no match for a wrong 'left' expression type } ) @@ -2119,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr( - left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } @@ -2142,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + if (arrayContainsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + getValueEval(value, ordinal, mapKeyType, ordering) } } @@ -2158,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val nullCheck = if (arrayContainsNull) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; @@ -2209,9 +2206,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti """) case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2228,7 +2223,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - override def dataType: DataType = { + @transient override lazy val dataType: DataType = { if (children.isEmpty) { StringType } else { @@ -2236,7 +2231,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - lazy val javaType: String = CodeGenerator.javaType(dataType) + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2256,9 +2251,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -2316,9 +2312,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -2413,15 +2410,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull - override def dataType: DataType = childDataType.elementType + @transient override lazy val dataType: DataType = childDataType.elementType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -2441,9 +2436,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -2476,9 +2472,10 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $MAX_ARRAY_LENGTH) { + |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | $variableName + " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin (code, variableName) @@ -2602,7 +2599,7 @@ case class Sequence( override def nullable: Boolean = children.exists(_.nullable) - override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) override def checkInputDataTypes(): TypeCheckResult = { val startType = start.dataType @@ -2633,7 +2630,7 @@ case class Sequence( stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), timeZoneId) - private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: SequenceImpl = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) @@ -2953,8 +2950,6 @@ object Sequence { case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) @@ -2966,9 +2961,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -3027,9 +3022,10 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -3111,7 +3107,7 @@ case class ArrayRemove(left: Expression, right: Expression) Seq(ArrayType, elementType) } - lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -3228,7 +3224,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType)