Skip to content

Commit

Permalink
[SPARK-24305][SQL][FOLLOWUP] Using def wherever it's possible and cov…
Browse files Browse the repository at this point in the history
…ering more expressions.
  • Loading branch information
mn-mikke committed Jul 16, 2018
1 parent a4d1e7f commit 62c55ad
Showing 1 changed file with 30 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,21 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI

override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)

override def dataType: DataType = ArrayType(mountSchema)

override def nullable: Boolean = children.exists(_.nullable)

private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])

private lazy val arrayElementTypes = arrayTypes.map(_.elementType)

@transient private lazy val mountSchema: StructType = {
@transient override lazy val dataType: DataType = {
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(idx.toString, elementType, nullable = true)
}
StructType(fields)
ArrayType(StructType(fields), containsNull = false)
}

@transient lazy val numberOfArrays: Int = children.length
override def nullable: Boolean = children.exists(_.nullable)

private def arrayElementTypes = children.map(_.dataType.asInstanceOf[ArrayType].elementType)

@transient lazy val genericArrayData = classOf[GenericArrayData].getName
private def genericArrayData = classOf[GenericArrayData].getName

def emptyInputGenCode(ev: ExprCode): ExprCode = {
ev.copy(code"""
Expand Down Expand Up @@ -256,7 +250,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
("ArrayData[]", arrVals) :: Nil)

val initVariables = s"""
|ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
|ArrayData[] $arrVals = new ArrayData[${children.length}];
|int $biggestCardinality = 0;
|${CodeGenerator.javaType(dataType)} ${ev.value} = null;
""".stripMargin
Expand All @@ -268,7 +262,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
|if (!${ev.isNull}) {
| Object[] $args = new Object[$biggestCardinality];
| for (int $i = 0; $i < $biggestCardinality; $i ++) {
| Object[] $currentRow = new Object[$numberOfArrays];
| Object[] $currentRow = new Object[${children.length}];
| $getValueForTypeSplitted
| $args[$i] = new $genericInternalRow($currentRow);
| }
Expand All @@ -278,7 +272,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (numberOfArrays == 0) {
if (children.length == 0) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
Expand Down Expand Up @@ -360,7 +354,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
private def childDataType: MapType = child.dataType.asInstanceOf[MapType]

override def dataType: DataType = {
ArrayType(
Expand Down Expand Up @@ -741,14 +735,15 @@ case class MapConcat(children: Seq[Expression]) extends Expression {
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {

@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
case ArrayType(
@transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = {
child.dataType match {
case ArrayType(
StructType(Array(
StructField(_, keyType, keyNullable, _),
StructField(_, valueType, valueNullable, _))),
StructField(_, keyType, keyNullable, _),
StructField(_, valueType, valueNullable, _))),
containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull))
case _ => None
case _ => None
}
}

private def nullEntries: Boolean = dataTypeDetails.get._3
Expand Down Expand Up @@ -953,8 +948,7 @@ trait ArraySortLike extends ExpectsInputTypes {

protected def nullOrder: NullOrder

@transient
private lazy val lt: Comparator[Any] = {
@transient private lazy val lt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -976,8 +970,7 @@ trait ArraySortLike extends ExpectsInputTypes {
}
}

@transient
private lazy val gt: Comparator[Any] = {
@transient private lazy val gt: Comparator[Any] = {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand Down Expand Up @@ -1215,8 +1208,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI

override def dataType: DataType = child.dataType

@transient
private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
Expand Down Expand Up @@ -1608,8 +1600,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)

override def children: Seq[Expression] = Seq(x, start, length)

@transient
private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
val startInt = startVal.asInstanceOf[Int]
Expand Down Expand Up @@ -1895,8 +1886,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

@transient
private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
@transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
Expand Down Expand Up @@ -1961,8 +1951,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

@transient
private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
@transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

override def checkInputDataTypes(): TypeCheckResult = {
val typeCheckResult = super.checkInputDataTypes()
Expand Down Expand Up @@ -2236,8 +2225,7 @@ case class Concat(children: Seq[Expression]) extends Expression {

override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

@transient
lazy val javaType: String = CodeGenerator.javaType(dataType)
private def javaType: String = CodeGenerator.javaType(dataType)

override def nullable: Boolean = children.exists(_.nullable)

Expand Down Expand Up @@ -2416,15 +2404,13 @@ case class Concat(children: Seq[Expression]) extends Expression {
since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

@transient
private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

override def nullable: Boolean = child.nullable || childDataType.containsNull

override def dataType: DataType = childDataType.elementType

@transient
private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(_: ArrayType, _) =>
Expand Down Expand Up @@ -2607,7 +2593,7 @@ case class Sequence(

override def nullable: Boolean = children.exists(_.nullable)

override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)

override def checkInputDataTypes(): TypeCheckResult = {
val startType = start.dataType
Expand Down Expand Up @@ -2638,7 +2624,7 @@ case class Sequence(
stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
timeZoneId)

private lazy val impl: SequenceImpl = dataType.elementType match {
@transient private lazy val impl: SequenceImpl = dataType.elementType match {
case iType: IntegralType =>
type T = iType.InternalType
val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
Expand Down Expand Up @@ -3115,7 +3101,7 @@ case class ArrayRemove(left: Expression, right: Expression)
Seq(ArrayType, elementType)
}

lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
Expand Down Expand Up @@ -3232,7 +3218,7 @@ case class ArrayDistinct(child: Expression)

override def dataType: DataType = child.dataType

@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
private def elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)
Expand Down

0 comments on commit 62c55ad

Please sign in to comment.