From fa42d97f8c4ed0193326cb00b5dfcb5c9001175e Mon Sep 17 00:00:00 2001
From: Kazuaki Ishizaki
Date: Tue, 28 Feb 2017 04:35:04 +0900
Subject: [PATCH] Retry an execution by calling eval() if an exception is caught

When Janino fails to compile the generated Java code for a predicate
(a CompileException, or a JaninoRuntimeException such as a generated
method growing beyond the JVM code size limit), SparkPlan.newPredicate
now falls back to interpreted evaluation through InterpretedPredicate
instead of failing the query. To make this possible, CodeGenerator.compile()
unwraps the UncheckedExecutionException or ExecutionError that Guava's
cache may wrap around the original exception, and InterpretedPredicate
becomes a class with an eval() method (extending codegen.Predicate)
rather than a bare InternalRow => Boolean function, so call sites now
invoke eval(). The fallback is guarded by SQLConf.wholeStageFallback.
---
 .../expressions/codegen/CodeGenerator.scala   | 23 +++++++++++++++----
 .../sql/catalyst/expressions/predicates.scala | 10 ++++---
 .../spark/sql/execution/SparkPlan.scala       | 23 ++++++++++++++++++-
 .../PartitioningAwareFileIndex.scala          |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala | 11 +++++++++
 .../spark/sql/hive/HiveExternalCatalog.scala  |  2 +-
 .../sql/sources/SimpleTextRelation.scala      |  2 +-
 7 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 760ead42c762c..9591025471bd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -27,7 +27,10 @@ import scala.language.existentials
 import scala.util.control.NonFatal
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
-import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
+import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler}
 import org.codehaus.janino.util.ClassFile
 
 import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
@@ -899,8 +902,16 @@ object CodeGenerator extends Logging {
   /**
    * Compile the Java source code into a Java class, using Janino.
    */
-  def compile(code: CodeAndComment): GeneratedClass = {
+  def compile(code: CodeAndComment): GeneratedClass = try {
     cache.get(code)
+  } catch {
+    // Cache.get() may wrap the original exception. See the following URL
+    // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/
+    //   Cache.html#get(K,%20java.util.concurrent.Callable)
+    case e @ (_: UncheckedExecutionException | _: ExecutionError) =>
+      val excChains = ExceptionUtils.getThrowables(e)
+      val exc = if (excChains.length == 1) excChains(0) else excChains(excChains.length - 2)
+      throw exc
   }
 
   /**
@@ -951,10 +962,14 @@ object CodeGenerator extends Logging {
       evaluator.cook("generated.java", code.body)
       recordCompilationStats(evaluator)
     } catch {
-      case e: Exception =>
+      case e: JaninoRuntimeException =>
         val msg = s"failed to compile: $e\n$formatted"
         logError(msg, e)
-        throw new Exception(msg, e)
+        throw new JaninoRuntimeException(msg, e)
+      case e: CompileException =>
+        val msg = s"failed to compile: $e\n$formatted"
+        logError(msg, e)
+        throw new CompileException(msg, e.getLocation)
     }
     evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
   }
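Note on the compile() change: Guava's Cache.get(K, Callable) wraps any unchecked exception thrown by the loader in an UncheckedExecutionException (and any Error in an ExecutionError), so without the unwrapping above a caller could never pattern-match on CompileException or JaninoRuntimeException. A minimal, self-contained sketch of that behavior (the names here are illustrative, not part of the patch):

    import com.google.common.cache.{CacheBuilder, CacheLoader}
    import com.google.common.util.concurrent.UncheckedExecutionException

    object CacheUnwrapSketch {
      // Stands in for the codegen cache; the loader plays the role of doCompile().
      private val cache = CacheBuilder.newBuilder().maximumSize(100).build(
        new CacheLoader[String, String]() {
          override def load(code: String): String =
            throw new IllegalStateException(s"failed to compile: $code")
        })

      def compile(code: String): String = try {
        cache.get(code)
      } catch {
        // get() delivers the loader's unchecked exception wrapped; rethrow
        // the original so callers can match on its real type.
        case e: UncheckedExecutionException => throw e.getCause
      }
    }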
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ac56ff13fa5bf..6552c0d8e0267 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -20,20 +20,22 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
 object InterpretedPredicate {
-  def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) =
+  def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
     create(BindReferences.bindReference(expression, inputSchema))
 
-  def create(expression: Expression): (InternalRow => Boolean) = {
-    (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
-  }
+  def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
 }
 
+class InterpretedPredicate(expression: Expression) extends BasePredicate {
+  def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+}
 
 /**
  * An [[Expression]] that returns a boolean value.
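The predicates.scala change turns the factory's result from a bare InternalRow => Boolean function into a class with an eval() method, sharing the shape of codegen.Predicate so generated and interpreted predicates are interchangeable at call sites. A hedged usage sketch (the setup around create() is illustrative only, not from the patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object InterpretedPredicateSketch {
      def main(args: Array[String]): Unit = {
        // A predicate over a single int column: a > 1.
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val predicate = InterpretedPredicate.create(GreaterThan(a, Literal(1)), Seq(a))

        // eval() interprets the expression tree; no Java source is generated.
        assert(predicate.eval(InternalRow(2)))
        assert(!predicate.eval(InternalRow(0)))
      }
    }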
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index cadab37a449aa..64a44696bcc77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext
 
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.JaninoRuntimeException
+
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
@@ -353,9 +356,27 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
     GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination)
   }
 
+  private def genInterpretedPredicate(
+      expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
+    val str = expression.toString
+    val logMessage = if (str.length > 256) {
+      str.substring(0, 256 - 3) + "..."
+    } else {
+      str
+    }
+    logWarning(s"Codegen disabled for this expression:\n $logMessage")
+    InterpretedPredicate.create(expression, inputSchema)
+  }
+
   protected def newPredicate(
       expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
-    GeneratePredicate.generate(expression, inputSchema)
+    try {
+      GeneratePredicate.generate(expression, inputSchema)
+    } catch {
+      case _: JaninoRuntimeException | _: CompileException
+          if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+        genInterpretedPredicate(expression, inputSchema)
+    }
   }
 
   protected def newOrdering(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index c8097a7fabc2e..75b5384d355e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -180,7 +180,7 @@ abstract class PartitioningAwareFileIndex(
     })
 
     val selected = partitions.filter {
-      case PartitionPath(values, _) => boundPredicate(values)
+      case PartitionPath(values, _) => boundPredicate.eval(values)
     }
     logInfo {
       val total = partitions.length
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 19c2d5532d088..46919d5b6556c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1703,4 +1703,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
     checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
   }
+
+  test("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") {
+    val N = 400
+    val rows = Seq(Row.fromSeq(Seq.fill(N)("string")))
+    val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType)))
+    val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
+
+    val filter = (0 until N)
+      .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
+    df.filter(filter).count
+  }
 }
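Note on the new test: OR-ing 400 column comparisons yields generated Java whose single method exceeds the JVM's 64KB-per-method bytecode limit, so Janino throws at compile time and newPredicate above takes the interpreted path instead of failing the query, logging the "Codegen disabled" warning. Stripped of the SparkPlan plumbing and the wholeStageFallback guard, the shape of that fallback is (a sketch, not the patch's exact code):

    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InterpretedPredicate}
    import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, Predicate}
    import org.codehaus.commons.compiler.CompileException
    import org.codehaus.janino.JaninoRuntimeException

    object PredicateFallbackSketch {
      def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): Predicate = {
        try {
          // Fast path: compile the bound expression to Java bytecode with Janino.
          GeneratePredicate.generate(expression, inputSchema)
        } catch {
          // Janino rejects methods that would exceed the bytecode size limit;
          // retry by walking the expression tree with eval() instead.
          case _: JaninoRuntimeException | _: CompileException =>
            InterpretedPredicate.create(expression, inputSchema)
        }
      }
    }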
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 43d9c2bec6823..d09160f4e4da9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -1036,7 +1036,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
         BoundReference(index, partitionSchema(index).dataType, nullable = true)
       })
       clientPrunedPartitions.filter { p =>
-        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+        boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId))
       }
     } else {
       client.getPartitions(catalogTable).map { part =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 1607c97cd6acb..b04ef6f21d69f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
         // `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
         Cast(Literal(value), dataType).eval()
       })
-    }.filter(predicate).map(projection)
+    }.filter(predicate.eval).map(projection)
 
     // Appends partition values
     val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
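These last two hunks only adapt call sites: boundPredicate and predicate are no longer InternalRow => Boolean functions, so they are invoked through eval() (in the SimpleTextRelation case, predicate.eval eta-expands back to a function for filter). Whether the fallback fires at all is controlled by SQLConf.wholeStageFallback; assuming that maps to the spark.sql.codegen.fallback key (worth verifying against the Spark version at hand), it can be toggled per session:

    // `spark` is the usual SparkSession, e.g. in spark-shell.
    spark.conf.set("spark.sql.codegen.fallback", "true")   // retry with eval() on compile errors
    spark.conf.set("spark.sql.codegen.fallback", "false")  // let the compile exception propagate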