From 19bd17c6c7447a2c23ad761e481e41d60a6c87a5 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 02:56:26 +0900 Subject: [PATCH 01/10] Go to codegen fallback if the total number of NewInstances whose code has been generated exceeds the pre-defined threshold --- .../expressions/codegen/CodeGenerator.scala | 3 + .../expressions/objects/objects.scala | 56 ++++++++++++++--- .../org/apache/spark/sql/types/DataType.scala | 23 ++++++- .../spark/sql/DataFrameComplexTypeSuite.scala | 61 +++++++++++++++++++ 4 files changed, 134 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 8b74d606dbb26..91edd21463bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,6 +185,9 @@ class CodegenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = "double" + /** The number of total complex types whose code has been generated **/ + var accumulatedComplexTypeGenCode = 1 + /** The variable name of the input row in generated code. 
*/ final var INPUT_ROW = "i" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5e17f8920901a..5c564f421b133 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -57,8 +57,16 @@ case class StaticInvoke( override def nullable: Boolean = true override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + override def eval(input: InternalRow): Any = { + val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) + if (argVals.find(_ == null).isDefined) { + return null + } + + val argTypes = arguments.map(e => DataType.javaType(e.dataType)) + staticObject.getMethod(functionName, argTypes: _*) + .invoke(null, argVals: _*) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) @@ -112,10 +120,20 @@ case class Invoke( propagateNull: Boolean = true) extends Expression with NonSQLExpression { override def nullable: Boolean = true + override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is 
supported.") + override def eval(input: InternalRow): Any = { + val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) + if (argVals.find(_ == null).isDefined) { + return null + } + + val argTypes = arguments.map(e => DataType.javaType(e.dataType)) + DataType.javaType(targetObject.dataType) + .getMethod(functionName, argTypes: _*) + .invoke(targetObject.eval(input), argVals: _*) + } @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => @@ -210,7 +228,8 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { + outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression + with CodegenFallback { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -227,10 +246,31 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + override def eval(input: InternalRow): Any = { + if (arguments.length == 0) { + return cls.getConstructor().newInstance() + } + + val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) + if (argVals.find(_ == null).isDefined) { + return null + } + + if (outerPointer.isDefined) { + throw new UnsupportedOperationException( + "For inner class, only code-generated evaluation is supported.") + } + + val argTypes = arguments.map(e => DataType.javaType(e.dataType)) + cls.getConstructor(argTypes: _*) + .newInstance(argVals: _*) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ctx.accumulatedComplexTypeGenCode *= arguments.length + if (ctx.accumulatedComplexTypeGenCode > 400) { + return super[CodegenFallback].doGenCode(ctx, ev) + } val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4fc65cbce15bd..879fedcc3caf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -21,9 +21,10 @@ import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ - import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -245,4 +246,24 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + def javaType(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case dt: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => InternalRow.getClass + case _: ArrayType => classOf[org.apache.spark.sql.catalyst.util.ArrayData] + case _: MapType => classOf[org.apache.spark.sql.catalyst.util.MapData] + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[java.lang.Object] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225ee..d89b4c06d402d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -58,4 +58,65 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val nullIntRow = df.selectExpr("i[1]").collect()(0) assert(nullIntRow == org.apache.spark.sql.Row(null)) } + + test("SPARK-15285 Nested/Chained case statements generate codegen over 64kb") { + val ds100_5 = Seq(S100_5()).toDS() + ds100_5.show + val ds10_50 = Seq(S10_50()).toDS() + ds10_50.show + val ds10_7_6 = Seq(S10_7_6()).toDS() + ds10_7_6.show + } } + +case class S100( + s1: String = "1", s2: String = "2", s3: String = "3", s4: String = "4", + s5: String = "5", s6: String = "6", s7: String = "7", s8: String = "8", + s9: String = "9", s10: String = "10", s11: String = "11", s12: String = "12", + s13: String = "13", s14: String = "14", s15: String = "15", s16: String = "16", + s17: String = "17", s18: String = "18", s19: String = "19", s20: String = "20", + s21: String = "21", s22: String = "22", s23: String = "23", s24: String = "24", + s25: String = "25", s26: String = "26", s27: String = "27", s28: String = "28", + s29: String = "29", s30: String = "30", s31: String = "31", s32: String = "32", + s33: String = "33", s34: String = "34", s35: String = "35", s36: String = "36", + s37: String = "37", s38: String = "38", s39: String = "39", s40: String = "40", + s41: String = "41", s42: String = "42", s43: String = "43", s44: String = "44", + s45: String = "45", s46: String = "46", s47: String = "47", s48: String = "48", + s49: String = "49", s50: String = "50", s51: String = "51", s52: String = "52", + s53: String = "53", s54: String = "54", s55: String = "55", s56: String = "56", + s57: String = "57", s58: String = "58", s59: String = "59", s60: String = "60", + s61: String = "61", s62: String = "62", s63: String = "63", s64: String = "64", + s65: String = "65", s66: String = "66", s67: String = "67", s68: String 
= "68", + s69: String = "69", s70: String = "70", s71: String = "71", s72: String = "72", + s73: String = "73", s74: String = "74", s75: String = "75", s76: String = "76", + s77: String = "77", s78: String = "78", s79: String = "79", s80: String = "80", + s81: String = "81", s82: String = "82", s83: String = "83", s84: String = "84", + s85: String = "85", s86: String = "86", s87: String = "87", s88: String = "88", + s89: String = "89", s90: String = "90", s91: String = "91", s92: String = "92", + s93: String = "93", s94: String = "94", s95: String = "95", s96: String = "96", + s97: String = "97", s98: String = "98", s99: String = "99", s100: String = "100") +case class S100_5( + s1: S100 = S100(), s2: S100 = S100(), s3: S100 = S100(), s4: S100 = S100(), s5: S100 = S100()) + +case class S10( + s1: String = "1", s2: String = "2", s3: String = "3", s4: String = "4", s5: String = "5", + s6: String = "6", s7: String = "7", s8: String = "8", s9: String = "9", s10: String = "10") +case class S10_50( + s1: S10 = S10(), s2: S10 = S10(), s3: S10 = S10(), s4: S10 = S10(), s5: S10 = S10(), + s6: S10 = S10(), s7: S10 = S10(), s8: S10 = S10(), s9: S10 = S10(), s10: S10 = S10(), + s11: S10 = S10(), s12: S10 = S10(), s13: S10 = S10(), s14: S10 = S10(), s15: S10 = S10(), + s16: S10 = S10(), s17: S10 = S10(), s18: S10 = S10(), s19: S10 = S10(), s20: S10 = S10(), + s21: S10 = S10(), s22: S10 = S10(), s23: S10 = S10(), s24: S10 = S10(), s25: S10 = S10(), + s26: S10 = S10(), s27: S10 = S10(), s28: S10 = S10(), s29: S10 = S10(), s30: S10 = S10(), + s31: S10 = S10(), s32: S10 = S10(), s33: S10 = S10(), s34: S10 = S10(), s35: S10 = S10(), + s36: S10 = S10(), s37: S10 = S10(), s38: S10 = S10(), s39: S10 = S10(), s40: S10 = S10(), + s41: S10 = S10(), s42: S10 = S10(), s43: S10 = S10(), s44: S10 = S10(), s45: S10 = S10(), + s46: S10 = S10(), s47: S10 = S10(), s48: S10 = S10(), s49: S10 = S10(), s50: S10 = S10()) + +case class S10_7( + s1: S10 = S10(), s2: S10 = S10(), s3: S10 = S10(), s4: 
S10 = S10(), s5: S10 = S10(), + s6: S10 = S10(), s7: S10 = S10()) + +case class S10_7_6( + s1: S10_7 = S10_7(), s2: S10_7 = S10_7(), s3: S10_7 = S10_7(), s4: S10_7 = S10_7(), + s5: S10_7 = S10_7(), s6: S10_7 = S10_7()) From 7db7f915c202e516d4e22e688f2a61cf43453346 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 02:59:04 +0900 Subject: [PATCH 02/10] revert blank line --- .../src/main/scala/org/apache/spark/sql/types/DataType.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 879fedcc3caf9..9144db93bf8ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -21,6 +21,7 @@ import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression From 0e5e44e5fd03c8f5e50aa75518d01e5b579d013a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 03:21:11 +0900 Subject: [PATCH 03/10] fix scala style error --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 91edd21463bda..f190fe5bf9538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,7 +185,7 @@ class CodegenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = 
"double" - /** The number of total complex types whose code has been generated **/ + /* The number of total complex types whose code has been generated */ var accumulatedComplexTypeGenCode = 1 /** The variable name of the input row in generated code. */ From 7c35c12d7bf0709d1dfc1f62167aa5a82a1ac2a6 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 05:50:38 -0400 Subject: [PATCH 04/10] supportCodegen() with NewInstance (with CodegenFallback) expression returns true if its expression is for java or scala class --- .../spark/sql/execution/WholeStageCodegenExec.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 2a1ce735b74ea..dfd453eef71f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate @@ -420,7 +421,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true // CodegenFallback requires the input to be an InternalRow - case e: CodegenFallback => false + case e: CodegenFallback => + if (e.isInstanceOf[NewInstance]) { + // We assume that a class for java or Scala does not lead to CodegenFallback + val className = e.asInstanceOf[NewInstance].cls.getName + 
return (className.startsWith("java.") || className.startsWith("scala.")) + } + false case _ => true } From 5bdfaa73f8d458b9994dd1a11b535d9413674484 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 10:35:47 -0400 Subject: [PATCH 05/10] apply an approach to split the generated code --- .../expressions/codegen/CodeGenerator.scala | 3 - .../expressions/objects/objects.scala | 94 ++++++++----------- .../org/apache/spark/sql/types/DataType.scala | 22 ----- .../sql/execution/WholeStageCodegenExec.scala | 9 +- .../spark/sql/DataFrameComplexTypeSuite.scala | 29 +----- 5 files changed, 43 insertions(+), 114 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f190fe5bf9538..8b74d606dbb26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,9 +185,6 @@ class CodegenContext { final val JAVA_FLOAT = "float" final val JAVA_DOUBLE = "double" - /* The number of total complex types whose code has been generated */ - var accumulatedComplexTypeGenCode = 1 - /** The variable name of the input row in generated code. 
*/ final var INPUT_ROW = "i" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5c564f421b133..59aea7351a2e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -57,16 +57,8 @@ case class StaticInvoke( override def nullable: Boolean = true override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = { - val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) - if (argVals.find(_ == null).isDefined) { - return null - } - - val argTypes = arguments.map(e => DataType.javaType(e.dataType)) - staticObject.getMethod(functionName, argTypes: _*) - .invoke(null, argVals: _*) - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) @@ -120,20 +112,10 @@ case class Invoke( propagateNull: Boolean = true) extends Expression with NonSQLExpression { override def nullable: Boolean = true - override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = { - val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) - if 
(argVals.find(_ == null).isDefined) { - return null - } - - val argTypes = arguments.map(e => DataType.javaType(e.dataType)) - DataType.javaType(targetObject.dataType) - .getMethod(functionName, argTypes: _*) - .invoke(targetObject.eval(input), argVals: _*) - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => @@ -228,8 +210,7 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression - with CodegenFallback { + outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -246,53 +227,60 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = { - if (arguments.length == 0) { - return cls.getConstructor().newInstance() - } - - val argVals = arguments.map(_.eval(input).asInstanceOf[AnyRef]) - if (argVals.find(_ == null).isDefined) { - return null - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - if (outerPointer.isDefined) { - throw new UnsupportedOperationException( - "For inner class, only code-generated evaluation is supported.") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = ctx.javaType(dataType) + val argIsNulls = ctx.freshName("argIsNulls") + ctx.addMutableState("boolean[]", argIsNulls, "") + val argTypes = arguments.map(e => ctx.javaType(e.dataType)) + val argValues = arguments.zipWithIndex.map { case (e, i) => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(argTypes(i), argValue, "") + argValue } - val argTypes = arguments.map(e => DataType.javaType(e.dataType)) - 
cls.getConstructor(argTypes: _*) - .newInstance(argVals: _*) - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.accumulatedComplexTypeGenCode *= arguments.length - if (ctx.accumulatedComplexTypeGenCode > 400) { - return super[CodegenFallback].doGenCode(ctx, ev) + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + expr.code + s""" + $argIsNulls[$i] = ${expr.isNull}; + ${argValues(i)} = ${expr.value}; + """ } - val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) var isNull = ev.isNull val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};" + if (arguments.length <= 10) { + val argIsNull = arguments.zipWithIndex.map { case (e, i) => + s"$argIsNulls[$i]" + } + s"final boolean $isNull = ${argIsNull.mkString(" || ")};" + } else { + s""" + boolean $isNull = false; + for (int idx = 0; idx < ${arguments.length}; idx++) { + if ($argIsNulls[idx]) { $isNull = true; break; } + } + """ + } } else { isNull = "false" "" } val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" }.getOrElse { - s"new $className($argString)" + s"new $className(${argValues.mkString(", ")})" } val code = s""" - ${argGen.map(_.code).mkString("\n")} + $argIsNulls = new boolean[${arguments.size}]; + ${argCode.mkString("")} ${outer.map(_.code).getOrElse("")} $setIsNull final $javaType ${ev.value} = $isNull ? 
${ctx.defaultValue(javaType)} : $constructorCall; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 9144db93bf8ec..4fc65cbce15bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -23,9 +23,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -247,24 +245,4 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } - - def javaType(dt: DataType): Class[_] = dt match { - case BooleanType => java.lang.Boolean.TYPE - case ByteType => java.lang.Byte.TYPE - case ShortType => java.lang.Short.TYPE - case IntegerType | DateType => java.lang.Integer.TYPE - case LongType | TimestampType => java.lang.Long.TYPE - case FloatType => java.lang.Float.TYPE - case DoubleType => java.lang.Double.TYPE - case dt: DecimalType => classOf[Decimal] - case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] - case CalendarIntervalType => classOf[CalendarInterval] - case _: StructType => InternalRow.getClass - case _: ArrayType => classOf[org.apache.spark.sql.catalyst.util.ArrayData] - case _: MapType => classOf[org.apache.spark.sql.catalyst.util.MapData] - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) => cls - case _ => classOf[java.lang.Object] - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index dfd453eef71f0..2a1ce735b74ea 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -22,7 +22,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate @@ -421,13 +420,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true // CodegenFallback requires the input to be an InternalRow - case e: CodegenFallback => - if (e.isInstanceOf[NewInstance]) { - // We assume that a class for java or Scala does not lead to CodegenFallback - val className = e.asInstanceOf[NewInstance].cls.getName - return (className.startsWith("java.") || className.startsWith("scala.")) - } - false + case e: CodegenFallback => false case _ => true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index d89b4c06d402d..e34492f836279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -59,13 +59,9 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { assert(nullIntRow == org.apache.spark.sql.Row(null)) } - test("SPARK-15285 Nested/Chained case statements generate codegen over 64kb") { + test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") { val ds100_5 = Seq(S100_5()).toDS() ds100_5.show - val ds10_50 = 
Seq(S10_50()).toDS() - ds10_50.show - val ds10_7_6 = Seq(S10_7_6()).toDS() - ds10_7_6.show } } @@ -97,26 +93,3 @@ case class S100( s97: String = "97", s98: String = "98", s99: String = "99", s100: String = "100") case class S100_5( s1: S100 = S100(), s2: S100 = S100(), s3: S100 = S100(), s4: S100 = S100(), s5: S100 = S100()) - -case class S10( - s1: String = "1", s2: String = "2", s3: String = "3", s4: String = "4", s5: String = "5", - s6: String = "6", s7: String = "7", s8: String = "8", s9: String = "9", s10: String = "10") -case class S10_50( - s1: S10 = S10(), s2: S10 = S10(), s3: S10 = S10(), s4: S10 = S10(), s5: S10 = S10(), - s6: S10 = S10(), s7: S10 = S10(), s8: S10 = S10(), s9: S10 = S10(), s10: S10 = S10(), - s11: S10 = S10(), s12: S10 = S10(), s13: S10 = S10(), s14: S10 = S10(), s15: S10 = S10(), - s16: S10 = S10(), s17: S10 = S10(), s18: S10 = S10(), s19: S10 = S10(), s20: S10 = S10(), - s21: S10 = S10(), s22: S10 = S10(), s23: S10 = S10(), s24: S10 = S10(), s25: S10 = S10(), - s26: S10 = S10(), s27: S10 = S10(), s28: S10 = S10(), s29: S10 = S10(), s30: S10 = S10(), - s31: S10 = S10(), s32: S10 = S10(), s33: S10 = S10(), s34: S10 = S10(), s35: S10 = S10(), - s36: S10 = S10(), s37: S10 = S10(), s38: S10 = S10(), s39: S10 = S10(), s40: S10 = S10(), - s41: S10 = S10(), s42: S10 = S10(), s43: S10 = S10(), s44: S10 = S10(), s45: S10 = S10(), - s46: S10 = S10(), s47: S10 = S10(), s48: S10 = S10(), s49: S10 = S10(), s50: S10 = S10()) - -case class S10_7( - s1: S10 = S10(), s2: S10 = S10(), s3: S10 = S10(), s4: S10 = S10(), s5: S10 = S10(), - s6: S10 = S10(), s7: S10 = S10()) - -case class S10_7_6( - s1: S10_7 = S10_7(), s2: S10_7 = S10_7(), s3: S10_7 = S10_7(), s4: S10_7 = S10_7(), - s5: S10_7 = S10_7(), s6: S10_7 = S10_7()) From 74d876453cf853a97884ff195616e909e5873a2c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 22 May 2016 22:34:24 -0400 Subject: [PATCH 06/10] addressed review comments --- 
.../spark/sql/catalyst/expressions/objects/objects.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 59aea7351a2e7..64373566ec2b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -233,11 +233,11 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val argIsNulls = ctx.freshName("argIsNulls") - ctx.addMutableState("boolean[]", argIsNulls, "") - val argTypes = arguments.map(e => ctx.javaType(e.dataType)) + ctx.addMutableState("boolean[]", argIsNulls, + s"$argIsNulls = new boolean[${arguments.size}];") val argValues = arguments.zipWithIndex.map { case (e, i) => val argValue = ctx.freshName("argValue") - ctx.addMutableState(argTypes(i), argValue, "") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") argValue } @@ -279,7 +279,6 @@ case class NewInstance( } val code = s""" - $argIsNulls = new boolean[${arguments.size}]; ${argCode.mkString("")} ${outer.map(_.code).getOrElse("")} $setIsNull From 9cc4d41ecb69c0b6fecf11da5683f1114d418e4d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 May 2016 01:29:57 -0400 Subject: [PATCH 07/10] addressed review comments --- .../expressions/objects/objects.scala | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64373566ec2b3..dc4b9ed95bf56 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -254,19 +254,12 @@ case class NewInstance( var isNull = ev.isNull val setIsNull = if (propagateNull && arguments.nonEmpty) { - if (arguments.length <= 10) { - val argIsNull = arguments.zipWithIndex.map { case (e, i) => - s"$argIsNulls[$i]" - } - s"final boolean $isNull = ${argIsNull.mkString(" || ")};" - } else { - s""" - boolean $isNull = false; - for (int idx = 0; idx < ${arguments.length}; idx++) { - if ($argIsNulls[idx]) { $isNull = true; break; } - } - """ - } + s""" + boolean $isNull = false; + for (int idx = 0; idx < ${arguments.length}; idx++) { + if ($argIsNulls[idx]) { $isNull = true; break; } + } + """ } else { isNull = "false" "" @@ -279,7 +272,7 @@ case class NewInstance( } val code = s""" - ${argCode.mkString("")} + ${argCode} ${outer.map(_.code).getOrElse("")} $setIsNull final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; From 6e730ca6215afca78c0d3d34bd07218962475aaa Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 May 2016 05:44:24 -0400 Subject: [PATCH 08/10] addressed review comments --- .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index dc4b9ed95bf56..87bb0c5b295f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -272,7 +272,7 @@ case class NewInstance( } val code = s""" - ${argCode} + $argCode ${outer.map(_.code).getOrElse("")} $setIsNull final $javaType ${ev.value} = $isNull ? 
${ctx.defaultValue(javaType)} : $constructorCall; From 6f5bda30d2e24347501f02881132a734b0ed1d7c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 May 2016 14:38:28 -0400 Subject: [PATCH 09/10] addressed review comments --- .../scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index e34492f836279..12a464af806fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -91,5 +91,6 @@ case class S100( s89: String = "89", s90: String = "90", s91: String = "91", s92: String = "92", s93: String = "93", s94: String = "94", s95: String = "95", s96: String = "96", s97: String = "97", s98: String = "98", s99: String = "99", s100: String = "100") + case class S100_5( s1: S100 = S100(), s2: S100 = S100(), s3: S100 = S100(), s4: S100 = S100(), s5: S100 = S100()) From b92ab8cb8fd4d9aabda7486cf9d2d5dd7b98f0bf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 May 2016 21:59:16 -0400 Subject: [PATCH 10/10] addressed review comments --- .../scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 12a464af806fe..07fbaba0f819c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -61,7 +61,7 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") { val ds100_5 = Seq(S100_5()).toDS() - 
ds100_5.show + ds100_5.rdd.count } }