diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index accf2c72761bd..aac10df22853b 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -19,17 +19,18 @@ Collate: 'jobj.R' 'RDD.R' 'pairRDD.R' + 'SQLTypes.R' 'column.R' 'group.R' 'schema.R' 'DataFrame.R' 'SQLContext.R' + 'backend.R' 'broadcast.R' + 'client.R' 'context.R' 'deserialize.R' 'serialize.R' 'sparkR.R' - 'backend.R' - 'client.R' 'utils.R' 'zzz.R' diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6ecd35a26d026..21ef8323915fb 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -17,7 +17,7 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index e196305186b9a..b282001d8b6b5 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -17,7 +17,7 @@ # Column Class -#' @include generics.R jobj.R +#' @include generics.R jobj.R SQLTypes.R NULL setOldClass("jobj") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 09fc0a7abe48a..855fbdfc7c4ca 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -17,6 +17,9 @@ # group.R - GroupedData class and methods implemented in S4 OO classes +#' @include generics.R jobj.R SQLTypes.R column.R +NULL + setOldClass("jobj") #' @title S4 class that represents a GroupedData diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index 4180f146b7fbc..a8a25230b636d 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -18,6 +18,9 @@ # References to objects that exist on the JVM backend # are maintained using the jobj. +#' @include generics.R +NULL + # Maintain a reference count of Java object references # This allows us to GC the java object when it is safe .validJobjs <- new.env(parent = emptyenv()) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 02518f6d5bb4c..13efebc11c46e 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -16,6 +16,8 @@ # # Operations supported on RDDs contains pairs (i.e key, value) +#' @include generics.R jobj.R RDD.R +NULL ############ Actions and Transformations ############ diff --git a/bin/spark-class b/bin/spark-class index c03946d92e2e4..c49d97ce5cf25 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -82,13 +82,22 @@ if [ $(command -v "$JAR_CMD") ] ; then fi fi +LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" + +# Add the launcher build dir to the classpath if requested. +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" +fi + +export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$SPARK_ASSEMBLY_JAR" org.apache.spark.launcher.Main "$@") +done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") if [ "${CMD[0]}" = "usage" ]; then "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 4b3401d745f2a..3d068dd3a2739 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -46,13 +46,22 @@ if "%SPARK_ASSEMBLY_JAR%"=="0" ( exit /b 1 ) +set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% + +rem Add the launcher build dir to the classpath if requested. 
+if not "x%SPARK_PREPEND_CLASSES%"=="x" ( + set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% +) + +set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% + rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %SPARK_ASSEMBLY_JAR% org.apache.spark.launcher.Main %*"') do ( +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( set SPARK_CMD=%%i ) %SPARK_CMD% diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3f1a7dd99d635..e106c5c4bef60 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -31,6 +31,7 @@ import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -50,9 +51,10 @@ import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -192,8 +194,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - private[spark] val conf = config.clone() - conf.validateSettings() + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. 
| + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be @@ -201,65 +237,24 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } - - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") - - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[URI] = { - if (isEventLogEnabled) { - val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) - .stripSuffix("/") - Some(Utils.resolveURI(unresolvedDir)) - } else { - None - } - } - private[spark] val eventLogCodec: Option[String] = { - val compress = conf.getBoolean("spark.eventLog.compress", false) - if (compress && isEventLogEnabled) { - Some(CompressionCodec.getCodecName(conf)).map(CompressionCodec.getShortName) - } else { - None - } - } + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = 
_eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val isLocal = (master == "local" || master.startsWith("local[")) - - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - - // Create the Spark execution environment (cache, map output tracker, etc) - // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( conf: SparkConf, @@ -268,8 +263,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.createDriverEnv(conf, isLocal, listenerBus) } - private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -277,35 +271,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) - + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + def statusTracker: SparkStatusTracker = _statusTracker - val statusTracker = new SparkStatusTracker(this) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } - - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } - - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + private[spark] def ui: Option[SparkUI] = _ui /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. @@ -313,134 +286,248 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + def hadoopConfiguration: Configuration = _hadoopConfiguration + + private[spark] def executorMemory: Int = _executorMemory + + // Environment variables to pass to our executors. 
+ private[spark] val executorEnvs = HashMap[String, String]() + + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb } - if (files != null) { - files.foreach(addFile) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts } + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds + } + + def applicationId: String = _applicationId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } + + /* ------------------------------------------------------------------------------------- * + | Initialization. This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) + try { + _conf = config.clone() + _conf.validateSettings() - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v - } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. 
- executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } - // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Utils.getCurrentUserName() - executorEnvs("SPARK_USER") = sparkUser + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") - // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will - // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) - private val heartbeatReceiver = env.rpcEnv.setupEndpoint( - HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) + _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten - heartbeatReceiver.send(TaskSchedulerIsSet) + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } - @volatile private[spark] var dagScheduler: DAGScheduler = _ - try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val applicationId: String = taskScheduler.applicationId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - val metricsSystem = env.metricsSystem + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. 
- metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = - new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. - private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager,appName)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + if (files != null) { + files.foreach(addFile) } - } - cleaner.foreach(_.start()) - setupAndStartListenerBus() - postEnvironmentUpdate() - postApplicationStart() + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(512) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value + } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. 
(SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.send(TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _eventLogDir.get, _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - private[spark] var checkpointDir: Option[String] = None + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + _executorAllocationManager = + if (dynamicAllocationEnabled) { + assert(supportDynamicAllocation, + "Dynamic allocation of executors is currently only supported in YARN mode") + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -544,19 +631,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. 
@@ -1146,7 +1220,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * this application is supported. This is currently only available for YARN. */ private[spark] def supportDynamicAllocation = - master.contains("yarn") || dynamicAllocationTesting + master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) /** * :: DeveloperApi :: @@ -1163,7 +1237,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This is currently only supported in YARN mode. Return whether the request is received. */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1403,28 +1477,40 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def stop() { // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. - if (!stopped.compareAndSet(false, true)) { logInfo("SparkContext already stopped.") return } - + postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - executorAllocationManager.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.rpcEnv.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null + _ui.foreach(_.stop()) + if (env != null) { + env.metricsSystem.report() + } + if (metadataCleaner != null) { + metadataCleaner.cancel() + } + _cleaner.foreach(_.stop()) + _executorAllocationManager.foreach(_.stop()) + if (_dagScheduler != null) { + _dagScheduler.stop() + _dagScheduler = null + } + if (_listenerBusStarted) { + listenerBus.stop() + _listenerBusStarted = false + } + _eventLogger.foreach(_.stop()) + if (env != null && _heartbeatReceiver != null) { + env.rpcEnv.stop(_heartbeatReceiver) + } + _progressBar.foreach(_.stop()) + _taskScheduler = null // TODO: Cache.stop()? 
- env.stop() - SparkEnv.set(null) + if (_env != null) { + _env.stop() + SparkEnv.set(null) + } SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } @@ -1749,6 +1835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } listenerBus.start(this) + _listenerBusStarted = true } /** Post the application start event */ @@ -2152,7 +2239,7 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) @@ -2164,7 +2251,7 @@ object SparkContext extends Logging { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -2174,7 +2261,7 @@ object SparkContext extends Logging { // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index b7ae9c1fc0a23..ae99432f5ce86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -22,12 +22,13 @@ import java.net.URI private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], - val memoryPerSlave: Int, + val memoryPerExecutorMB: Int, val command: Command, var appUiUrl: String, val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. 
lzf) - val eventLogCodec: Option[String] = None) + val eventLogCodec: Option[String] = None, + val coresPerExecutor: Option[Int] = None) extends Serializable { val user = System.getProperty("user.name", "") @@ -35,13 +36,13 @@ private[spark] class ApplicationDescription( def copy( name: String = name, maxCores: Option[Int] = maxCores, - memoryPerSlave: Int = memoryPerSlave, + memoryPerExecutorMB: Int = memoryPerExecutorMB, command: Command = command, appUiUrl: String = appUiUrl, eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( - name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) + name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index dfc5b97e6a6c8..2954f932b4f41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -46,7 +46,7 @@ private[deploy] object JsonProtocol { ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ ("duration" -> obj.duration) @@ -55,7 +55,7 @@ private[deploy] object JsonProtocol { def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ + ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 60bc243ebf40a..296a0764b8baf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -406,6 +406,8 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options + OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 03ecf3fd99ec5..faa8780288ea3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -482,10 +482,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | + | Spark standalone and YARN only: + | --executor-cores NUM Number of cores per executor. (Default: 1 in YARN mode, + | or all available cores on the worker in standalone mode) + | | YARN-only: | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). - | --executor-cores NUM Number of cores per executor (Default: 1). 
| --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6e432d63c6b5a..3781b4e8c12bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -90,6 +90,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") ++ appTable + } else if (requestedIncomplete) { +

+          <h4>No incomplete applications found!</h4>
         } else {
           <h4>No completed applications found!</h4> ++
           <p>
Did you specify the correct logging directory? diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index bc5b293379f2b..f59d550d4f3b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -75,9 +75,11 @@ private[deploy] class ApplicationInfo( } } - private[master] def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): - ExecutorDesc = { - val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + private[master] def addExecutor( + worker: WorkerInfo, + cores: Int, + useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerExecutorMB) executors(exec.id) = exec coresGranted += cores exec diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9a5d5877da86d..c5a6b1beac9be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -524,52 +524,28 @@ private[master] class Master( } /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - private def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. This method will be called - * every time a new app joins or resource availability changes. + * Schedule executors to be launched on the workers. + * + * There are two modes of launching executors. The first attempts to spread out an application's + * executors on as many workers as possible, while the second does the opposite (i.e. launch them + * on as few workers as possible). The former is usually better for data locality purposes and is + * the default. + * + * The number of cores assigned to each executor is configurable. When this is explicitly set, + * multiple executors from the same application may be launched on the same worker if the worker + * has enough cores and memory. Otherwise, each executor grabs all the cores available on the + * worker by default, in which case only one executor may be launched on each worker. */ - private def schedule() { - if (state != RecoveryState.ALIVE) { return } - - // First schedule drivers, they take strict precedence over applications - // Randomization helps balance drivers - val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val numWorkersAlive = shuffledAliveWorkers.size - var curPos = 0 - - for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers - // We assign workers to each waiting driver in a round-robin fashion. For each driver, we - // start from the last worker that was assigned a driver, and continue onwards until we have - // explored all alive workers. 
- var launched = false - var numWorkersVisited = 0 - while (numWorkersVisited < numWorkersAlive && !launched) { - val worker = shuffledAliveWorkers(curPos) - numWorkersVisited += 1 - if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { - launchDriver(worker, driver) - waitingDrivers -= driver - launched = true - } - curPos = (curPos + 1) % numWorkersAlive - } - } - + private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores + // Try to spread out each app among all the workers, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -582,32 +558,61 @@ private[master] class Master( pos = (pos + 1) % numUsable } // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) - app.state = ApplicationState.RUNNING - } + for (pos <- 0 until numUsable if assigned(pos) > 0) { + allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } } else { - // Pack each app into as few nodes as possible until we've assigned all its cores + // Pack each app into as few workers as possible until we've assigned all its cores for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) - app.state = ApplicationState.RUNNING - } - } + allocateWorkerResourceToExecutors(app, app.coresLeft, worker) + } + } + } + } + + /** + * Allocate a worker's resources to one or more executors. + * @param app the info of the application which the executors belong to + * @param coresToAllocate cores on this worker to be allocated to this application + * @param worker the worker info + */ + private def allocateWorkerResourceToExecutors( + app: ApplicationInfo, + coresToAllocate: Int, + worker: WorkerInfo): Unit = { + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) + var coresLeft = coresToAllocate + while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { + val exec = app.addExecutor(worker, coresPerExecutor) + coresLeft -= coresPerExecutor + launchExecutor(worker, exec) + app.state = ApplicationState.RUNNING + } + } + + /** + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. 
+ */ + private def schedule(): Unit = { + if (state != RecoveryState.ALIVE) { return } + // Drivers take strict precedence over executors + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver } } } + startExecutorsOnWorkers() } - private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { + private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(masterUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 761aa8f7b1ef6..273f077bd8f57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -94,7 +94,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

             <li>
               <strong>Executor Memory:</strong>
-              {Utils.megabytesToString(app.desc.memoryPerSlave)}
+              {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
             </li>
             <li><strong>Submit Date:</strong> {app.submitDate}</li>
             <li><strong>State:</strong> {app.state}</li>
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 45412a35e9a7d..399f07399a0aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -208,8 +208,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.coresGranted} - - {Utils.megabytesToString(app.desc.memoryPerSlave)} + + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} {UIUtils.formatDate(app.submitDate)} {app.desc.user} diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 516f619529c48..1b5fdeba28ee2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -60,8 +60,6 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -114,6 +112,10 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("driver-heartbeater")) + startDriverHeartbeater() def launchTask( @@ -138,7 +140,8 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() env.rpcEnv.stop(executorEndpoint) - isStopped = true + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() @@ -432,23 +435,17 @@ private[spark] class Executor( } /** - * Starts a thread to report heartbeat and partial metrics for active tasks to driver. - * This thread stops running when the executor is stopped. + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. 
*/ private def startDriverHeartbeater(): Unit = { val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") - val thread = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(intervalMs + (math.random * intervalMs).asInstanceOf[Int]) - while (!isStopped) { - reportHeartBeat() - Thread.sleep(intervalMs) - } - } + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) } - thread.setDaemon(true) - thread.setName("driver-heartbeater") - thread.start() + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8e3c30fc3d781..5a74c13b38bf7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -86,11 +86,11 @@ private[nio] class ConnectionManager( conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -989,6 +989,7 @@ private[nio] class ConnectionManager( def stop() { ackTimeoutMonitor.stop() + selector.wakeup() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e259867c14040..f2c1c86af767e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -284,7 +284,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds /** * return the address for the [[RpcEndpointRef]] @@ -304,7 +304,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. 
*/ - def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + def sendWithReply[T: ClassTag](message: Any): Future[T] = + sendWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to @@ -327,7 +328,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 2362cc7240039..13a52d836f32f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, SPECULATION_INTERVAL_MS milliseconds) { Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - } + }(sc.env.actorSystem.dispatcher) } } @@ -394,7 +393,7 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { @@ -407,8 +406,7 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. 
- logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7eb3fdc19b5b8..ccf1dc5af6120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -82,12 +82,11 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec) - + val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, + command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() - waitForRegistration() } @@ -119,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. + sc.stop() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 70a477a6895cc..50ba0b9d5a612 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -20,12 +20,12 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer import java.util.concurrent.{Executors, TimeUnit} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.util.Utils private case class ReviveOffers() @@ -71,11 +71,15 @@ private[spark] class LocalEndpoint( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } + def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) val tasks = scheduler.resourceOffers(offers).flatten @@ -104,8 +108,11 @@ private[spark] class LocalEndpoint( * master all run in the same JVM. 
It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis var localEndpoint: RpcEndpointRef = null @@ -116,7 +123,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { - localEndpoint.send(StopExecutor) + localEndpoint.sendWithReply(StopExecutor) } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0c..81a7cbde01ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + override def isDefinedAt(o: Any): Boolean = { + val handled = _receiveWithLogging.isDefinedAt(o) + if (!handled) { + log.debug(s"Received unexpected actor system event: $o") + } + handled + } override def apply(o: Any): Unit = { if (log.isDebugEnabled) { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index d60b8b9a31a9b..a725767d08cc2 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,9 +19,12 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging +import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. 
@@ -64,4 +67,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ def onPostEvent(listener: L, event: E): Unit + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { + val c = implicitly[ClassTag[T]].runtimeClass + listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + } + } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 6b3049b28cd5e..22acc270b983e 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -56,19 +56,13 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf2) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 9cdb42814ca32..c93d16f8a1586 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.net.URL +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source @@ -65,16 +66,17 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { new MySparkConf().setAll(getAll) } } - val conf = new MySparkConf() + val conf = new MySparkConf().set( + "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - val listener = new SaveExecutorInfo - sc.addSparkListener(listener) - // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) info.logUrlMap.values.foreach { logUrl => @@ -82,12 +84,12 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { } } } +} - private class SaveExecutorInfo extends SparkListener { - val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() - override def onExecutorAdded(executor: SparkListenerExecutorAdded) { - addedExecutorInfos(executor.executorId) = executor.executorInfo - } + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo } } diff --git a/docs/configuration.md b/docs/configuration.md index 7169ec295ef7f..d9e9e67026cbb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md 
@@ -723,6 +723,17 @@ Apart from these, the following properties are also available, and may be useful this duration will be cleared as well. + + spark.executor.cores + 1 in YARN mode, all the available cores on the worker in standalone mode. + + The number of cores to use on each executor. For YARN and standalone mode only. + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one + executor per application will run on each worker. + + spark.default.parallelism diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 332618edf0c55..03500867df70f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1371,7 +1371,10 @@ the Data Sources API. The following options are supported: These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 1c8a20bf8f1ae..11a8cf09533ce 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -41,7 +41,7 @@ object DirectKafkaWordCount { | <brokers> is a list of one or more Kafka brokers | <topics> is a list of one or more kafka topics to consume from | - """".stripMargin) + """.stripMargin) System.exit(1) } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 9f22d40c15f3f..6d8b806569dfd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -65,7 +65,7 @@ object PowerIterationClusteringExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("PIC Circles") { + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") .text(s"number of circles (/clusters), default: ${defaultParams.k}") @@ -76,9 +76,9 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Int]('r', "r") + opt[Double]('r', "r") .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(numPoints = x)) + .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -154,3 +154,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } + diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a0b8a0c565210..a1b4a12e5d6a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -23,10 +23,9 @@ import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskC import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator -import java.util.Properties import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} -import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.consumer.SimpleConsumer import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index d8279145d8e90..b8f02b961113d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -186,12 +186,24 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - final String assembly = AbstractCommandBuilder.class.getProtectionDomain().getCodeSource(). - getLocation().getPath(); + // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as + // when running unit tests, or user code that embeds Spark and creates a SparkContext + // with a local or local-cluster master, will cause this code to be called from an + // environment where that env variable is not guaranteed to exist. + // + // For the testing case, we rely on the test code to set and propagate the test classpath + // appropriately. + // + // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. + // That duplicates some of the code in the shell scripts that look for the assembly, though. + String assembly = getenv(ENV_SPARK_ASSEMBLY); + if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + assembly = findAssembly(); + } addToClassPath(cp, assembly); - // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only - // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. 
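The sql-programming-guide note above (lowerBound and upperBound only choose the partition stride, they never filter rows) is easiest to see with a concrete read. A rough sketch using the Python load() API of this release line, assuming a pyspark shell with a sqlContext and a JDBC driver on the classpath; the URL, table, and column names are placeholders:

    # All option values are passed as strings to the JDBC data source.
    df = sqlContext.load(
        source="jdbc",
        url="jdbc:postgresql://dbserver/test",
        dbtable="schema.people",
        partitionColumn="id",   # must be a numeric column
        lowerBound="1",
        upperBound="100000",
        numPartitions="10")

    # Ten partitions, each covering a stride of roughly (100000 - 1) / 10 ids.
    # Rows whose id falls outside [1, 100000] are still returned; they simply
    # land in the first or last partition instead of being filtered out.
    print df.rdd.getNumPartitions()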
Both sbt and maven will populate // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive File libdir; if (new File(sparkHome, "RELEASE").isFile()) { @@ -299,6 +311,30 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } + private String findAssembly() { + String sparkHome = getSparkHome(); + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); + } + + final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); + FileFilter filter = new FileFilter() { + @Override + public boolean accept(File file) { + return file.isFile() && re.matcher(file.getName()).matches(); + } + }; + File[] assemblies = libdir.listFiles(filter); + checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); + checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); + return assemblies[0].getAbsolutePath(); + } + private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index f4ebc25bdd32b..8028e42ffb483 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -30,6 +30,7 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "512m"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; + static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 626116a9e7477..97043a76cc612 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -98,7 +98,7 @@ public void testShellCliParser() throws Exception { parser.NAME, "appName"); - List args = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List args = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); List expected = Arrays.asList("spark-shell", "--app-arg", "bar", "--app-switch"); assertEquals(expected, args.subList(args.size() - expected.size(), args.size())); } @@ -110,7 +110,7 @@ public void testAlternateSyntaxParsing() throws Exception { parser.MASTER + "=foo", parser.DEPLOY_MODE + "=bar"); - List cmd = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); assertEquals("foo", findArgValue(cmd, parser.MASTER)); assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); @@ -153,7 +153,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { String deployMode = isDriver ? 
"client" : "cluster"; SparkSubmitCommandBuilder launcher = - new SparkSubmitCommandBuilder(Collections.emptyList()); + newCommandBuilder(Collections.emptyList()); launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); launcher.master = "yarn"; @@ -273,10 +273,15 @@ private boolean findInStringList(String list, String sep, String needle) { return contains(needle, list.split(sep)); } - private List buildCommand(List args, Map env) throws Exception { + private SparkSubmitCommandBuilder newCommandBuilder(List args) { SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); - return builder.buildCommand(env); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); + return builder; + } + + private List buildCommand(List args, Map env) throws Exception { + return newCommandBuilder(args).buildCommand(env); } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index a50090671ae48..a1d49095c24ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,7 +25,7 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. */ private[ml] val uid: String = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 328dbe2ce11fa..4ef171f4f0419 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -227,7 +227,7 @@ object Vectors { * @param elements vector elements in (index, value) pairs. */ def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -235,7 +235,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -309,7 +310,8 @@ object Vectors { * @return norm in L^p^ space. */ def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs @@ -360,7 +362,8 @@ object Vectors { * @return squared distance between two Vectors. 
*/ def sqdist(v1: Vector, v2: Vector): Double = { - require(v1.size == v2.size, "vector dimension mismatch") + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => @@ -518,7 +521,9 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.size} indices and " + + s" ${values.size} values.") override def toString: String = "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 8f9ab687c05cb..641b64b42a5e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = getOrDefault(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = getOrDefault(inputCol) setDefault(maxIter -> 10) diff --git a/pom.xml b/pom.xml index 0b8d664dbfe1b..bcc2f57f1af5d 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ 3.6.3 1.8.8 2.4.4 - 1.1.1.6 + 1.1.1.7 1.1.2 ${java.home} @@ -1447,7 +1447,7 @@ org.scalastyle scalastyle-maven-plugin - 0.4.0 + 0.7.0 false true @@ -1456,13 +1456,12 @@ ${basedir}/src/main/scala ${basedir}/src/test/scala scalastyle-config.xml - scalastyle-output.xml + ${basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} ${project.reporting.outputEncoding} - package check diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f51f4b58f97a..09b4976d10c26 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -import java.io.File +import java.io._ import scala.util.Properties import scala.collection.JavaConversions._ @@ -166,6 +166,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Package pyspark artifacts in the main assembly. */ + enable(PySparkAssembly.settings)(assembly) + /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) @@ -316,6 +319,7 @@ object Hive { } object Assembly { + import sbtassembly.AssemblyUtils._ import sbtassembly.Plugin._ import AssemblyKeys._ @@ -347,6 +351,60 @@ object Assembly { ) } +object PySparkAssembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ + + lazy val settings = Seq( + unmanagedJars in Compile += { BuildCommons.sparkHome / "python/lib/py4j-0.8.2.1-src.zip" }, + // Use a resource generator to copy all .py files from python/pyspark into a managed directory + // to be included in the assembly. 
We can't just add "python/" to the assembly's resource dir + // list since that will copy unneeded / unwanted files. + resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + val dst = new File(outDir, "pyspark") + if (!dst.isDirectory()) { + require(dst.mkdirs()) + } + + val src = new File(BuildCommons.sparkHome, "python/pyspark") + copy(src, dst) + } + ) + + private def copy(src: File, dst: File): Seq[File] = { + src.listFiles().flatMap { f => + val child = new File(dst, f.getName()) + if (f.isDirectory()) { + child.mkdir() + copy(f, child) + } else if (f.getName().endsWith(".py")) { + var in: Option[FileInputStream] = None + var out: Option[FileOutputStream] = None + try { + in = Some(new FileInputStream(f)) + out = Some(new FileOutputStream(child)) + + val bytes = new Array[Byte](1024) + var read = 0 + while (read >= 0) { + read = in.get.read(bytes) + if (read > 0) { + out.get.write(bytes, 0, read) + } + } + + Some(child) + } finally { + in.foreach(_.close()) + out.foreach(_.close()) + } + } else { + None + } + } + } +} + object Unidoc { import BuildCommons._ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7f42de531f3b4..d7bc09fd77adb 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.1) """ super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def _create_model(self, java_model): return LogisticRegressionModel(java_model) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1cfcd019dfb18..263fe2a5bcc41 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): _java_class = "org.apache.spark.ml.feature.Tokenizer" @keyword_only - def __init__(self, inputCol="input", outputCol="output"): + def __init__(self, inputCol=None, outputCol=None): """ - __init__(self, inputCol="input", outputCol="output") + __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol="input", outputCol="output"): + def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") Sets params for this Tokenizer. 
""" kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) @inherit_doc @@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): _java_class = "org.apache.spark.ml.feature.HashingTF" @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) Sets params for this HashingTF. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index e3a53dd780c4c..5c62620562a84 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -25,23 +25,21 @@ class Param(object): """ - A param with self-contained documentation and optionally default value. + A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): - if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) self.parent = parent self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) class Params(Identifiable): @@ -52,26 +50,128 @@ class Params(Identifiable): __metaclass__ = ABCMeta - def __init__(self): - super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ return filter(lambda attr: isinstance(attr, Param), [getattr(self, x) for x in dir(self) if x != "params"]) - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. 
+ """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. + """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) return paramMap + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ @@ -81,10 +181,18 @@ def _dummy(): dummy.uid = "undefined" return dummy - def _set_params(self, **kwargs): + def _set(self, **kwargs): """ - Sets params. + Sets user-supplied params. """ for param, value in kwargs.iteritems(): self.paramMap[getattr(self, param)] = value return self + + def _setDefault(self, **kwargs): + """ + Sets default params. 
+ """ + for param, value in kwargs.iteritems(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py similarity index 70% rename from python/pyspark/ml/param/_gen_shared_params.py rename to python/pyspark/ml/param/_shared_params_code_gen.py index 5eb81106f116c..55f422497672f 100644 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -32,29 +32,34 @@ # limitations under the License. #""" +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py -def _gen_param_code(name, doc, defaultValue): + +def _gen_param_code(name, doc, defaultValueStr): """ Generates Python code for a shared param class. :param name: param name :param doc: param doc - :param defaultValue: string representation of the param + :param defaultValueStr: string representation of the default value :return: code string """ # TODO: How to correctly inherit instance attributes? template = '''class Has$Name(Params): """ - Params with $name. + Mixin for param $name: $doc. """ # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + $name = Param(Params._dummy(), "$name", "$doc") def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) def set$Name(self, value): """ @@ -67,32 +72,29 @@ def get$Name(self): """ Gets the value of $name or its default value. """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' + return self.getOrDefault(self.$name)''' - upperCamelName = name[0].upper() + name[1:] + Name = name[0].upper() + name[1:] return template \ .replace("$name", name) \ - .replace("$Name", upperCamelName) \ + .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) + .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n" print "from pyspark.ml.param import Param, Params\n\n" shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de423..13b6749998ad0 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,23 +15,25 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations. """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) def setMaxIter(self, value): """ @@ -44,24 +46,23 @@ def getMaxIter(self): """ Gets the value of maxIter or its default value. """ - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization constant. """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization constant") def __init__(self): super(HasRegParam, self).__init__() #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) def setRegParam(self, value): """ @@ -74,24 +75,23 @@ def getRegParam(self): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. 
""" # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ @@ -104,24 +104,23 @@ def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') def setLabelCol(self, value): """ @@ -134,24 +133,23 @@ def getLabelCol(self): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ @@ -164,24 +162,23 @@ def getPredictionCol(self): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) def setInputCol(self, value): """ @@ -194,24 +191,23 @@ def getInputCol(self): """ Gets the value of inputCol or its default value. 
""" - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) def setOutputCol(self, value): """ @@ -224,24 +220,23 @@ def getOutputCol(self): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ @@ -254,7 +249,4 @@ def getNumFeatures(self): """ Gets the value of numFeatures or its default value. """ - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 83880a5afcd1d..d94ecfff09f66 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -124,10 +124,10 @@ def setParams(self, stages=[]): Sets params for Pipeline. 
""" kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def fit(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -164,7 +164,7 @@ def __init__(self, transformers): self.transformers = transformers def transform(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e930b..3a42bcf723894 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -33,6 +33,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol from pyspark.ml.pipeline import Transformer, Estimator, Pipeline @@ -46,7 +47,7 @@ class MockTransformer(Transformer): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None @@ -62,7 +63,7 @@ class MockEstimator(Estimator): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None self.model = None @@ -111,5 +112,52 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. + """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6f7f39c40eb5a..d3cb100a9efa5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,8 +40,8 @@ class Identifiable(object): def __init__(self): #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. 
- self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: concatenates the class name, "_", and 8 random hex chars. + self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] def __repr__(self): return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 31a66b3d2f730..394f23c5e9b12 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -64,7 +64,7 @@ def _transfer_params_to_java(self, params, java_obj): :param params: additional params (overwriting embedded values) :param java_obj: Java object to receive the params """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: java_obj.set(param.name, paramMap[param]) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c9ac95d117574..93e658eded9e2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1197,7 +1197,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1260,7 +1260,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): ser = CloudPickleSerializer() pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # tracking the life cycle by obj - if obj is not None: - obj._broadcast = broadcast broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) @@ -2294,12 +2292,9 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None - self._broadcast = None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef91a9c4f522d..f2c3b74a185cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -456,7 +456,7 @@ def join(self, other, joinExprs=None, joinType=None): One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. 
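take() and isEmpty() above now go through getNumPartitions() instead of reaching into the Java RDD directly, and PipelinedRDD grows its own getNumPartitions so the same call keeps working on mapped RDDs. Assuming a SparkContext named sc, as in the pyspark shell:

    rdd = sc.parallelize(range(100), 4)
    print rdd.getNumPartitions()                        # 4
    print rdd.map(lambda x: x * 2).getNumPartitions()   # 4, PipelinedRDD path
    print rdd.take(3)                                   # [0, 1, 2]
    print sc.parallelize([]).isEmpty()                  # True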
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] """ if joinExprs is None: @@ -637,9 +637,9 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) @@ -867,11 +867,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(MIN(age)=5), Row(MIN(age)=2)] + [Row(MIN(age)=2), Row(MIN(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b938b9ce12395..ee67e80d539f8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -550,10 +550,8 @@ def test_large_closure(self): data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d794f034f5578..ac8a782976465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{StructType, DateUtils} +import org.apache.spark.sql.types.StructType object Row { /** @@ -257,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. 
*/ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 91976fef6dc0d..d4f9fdacda4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -77,6 +77,9 @@ object CatalystTypeConverters { } new GenericRowWithSchema(ar, structType) + case (d: String, _) => + UTF8String(d) + case (d: BigDecimal, _) => Decimal(d) @@ -175,6 +178,11 @@ object CatalystTypeConverters { case other => other } + case dataType: StringType => (item: Any) => extractOption(item) match { + case s: String => UTF8String(s) + case other => other + } + case _ => (item: Any) => extractOption(item) match { case d: BigDecimal => Decimal(d) @@ -184,6 +192,26 @@ object CatalystTypeConverters { } } + /** + * Converts Scala objects to catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter @@ -211,6 +239,9 @@ object CatalystTypeConverters { case (i: Int, DateType) => DateUtils.toJavaDate(i) + case (s: UTF8String, StringType) => + s.toString() + case (other, _) => other } @@ -262,6 +293,12 @@ object CatalystTypeConverters { case other => other } + case StringType => + (item: Any) => item match { + case s: UTF8String => s.toString() + case other => other + } + case other => (item: Any) => item } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 01d5c1512201a..d9521953cad73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -138,6 +138,7 @@ trait ScalaReflection { // The data type can be determined without ambiguity. 
case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType + case obj: String => StringType case obj: StringType.JvmType => StringType case obj: ByteType.JvmType => ByteType case obj: ShortType.JvmType => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index bc8d3751f6616..0af969cc5cc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -121,14 +121,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } protected lazy val start: Parser[LogicalPlan] = - ( (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - | insert - | cte + start1 | insert | cte + + protected lazy val start1: Parser[LogicalPlan] = + (select | ("(" ~> select <~ ")")) * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) protected lazy val select: Parser[LogicalPlan] = @@ -159,7 +159,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } protected lazy val cte: Parser[LogicalPlan] = - WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start <~ ")"), ",") ~ start ^^ { + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) } @@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ UnresolvedAttribute + | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." 
~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) + case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b68b0df35f48..cb49e5ad5586f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -297,14 +297,15 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && + resolver(nameParts(0), VirtualColumn.groupingIdName) && q.isInstanceOf[GroupingAnalytics] => // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid - case u @ UnresolvedAttribute(name) => + case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => @@ -383,12 +384,12 @@ class Analyzer( child: LogicalPlan, grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { // Find any attributes that remain unresolved in the sort. - val unresolved: Seq[String] = - ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val unresolved: Seq[Seq[String]] = + ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) // Create a map from name, to resolved attributes, when the desired name can be found // prior to the projection. - val resolved: Map[String, NamedExpression] = + val resolved: Map[Seq[String], NamedExpression] = unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap // Construct a set that contains all of the attributes that we need to evaluate the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fa02111385c06..1155dac28fc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -46,8 +46,12 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => if (operator.childrenResolved) { + val nameParts = a match { + case UnresolvedAttribute(nameParts) => nameParts + case _ => Seq(a.name) + } // Throw errors for specific problems with get field. 
- operator.resolveChildren(a.name, resolver, throwErrors = true) + operator.resolveChildren(nameParts, resolver, throwErrors = true) } val from = operator.inputSet.map(_.name).mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3aeb964994d37..35c7f00d4e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -115,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal.create("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -563,6 +563,10 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 300e9ba187bc5..3f567e3e8b2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,7 +49,12 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) + extends Attribute with trees.LeafNode[Expression] { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this - override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. 
override def eval(input: Row = null): EvaluatedType = @@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } +object UnresolvedAttribute { + def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) +} + case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 31f1a5fdc7e53..adf941ab2a45f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. */ @@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Int](_, d => DateUtils.toString(d)) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s)) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => @@ -245,7 +245,7 @@ case class Cast(child: Expression, 
dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { + buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => @@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 47b6f358ed1b1..3475ed05f4454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { + 
values(ordinal).update(value) + } } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) - override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d141354a0f427..be2c101d63a63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1: BinaryType, e2: BinaryType) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } @@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" case dt @ NativeType() => 
q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { + case StringType => q"$destinationRow.update($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } @@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a8880..6f572ff959fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } val specificMutatorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? 
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..67caadb839ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ /** @@ -85,8 +85,11 @@ case class UserDefinedGenerator( override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + // Convert the objects into Scala Type before calling function, we need schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(inputRow(input)) + function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0e2d593e94124..18cba4cc46707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ object Literal { @@ -29,7 +30,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -42,7 +43,9 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } - def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) + def create(v: Any, dataType: DataType): Literal = { + Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 
7e47cb3fffe12..fcd6352079b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.compare( - l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 0a275b84086cf..b6ec7d3417ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} - +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -37,6 +36,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -114,9 +114,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - values(i).asInstanceOf[String] + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. 
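After this change the internal value of a string column is a `UTF8String`, while `setString`/`getString` keep speaking `java.lang.String` at the boundary (see the `GenericMutableRow.setString` hunk just below). A hedged round-trip sketch using only the constructors visible in this patch; the single-column row is illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.UTF8String

val row = new GenericMutableRow(new Array[Any](1))

row.setString(0, "spark")            // stored internally as UTF8String("spark")
row(0).isInstanceOf[UTF8String]      // true
row.getString(0)                     // "spark", converted back via toString
```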
override def hashCode: Int = { var result: Int = 37 @@ -189,8 +195,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } @@ -233,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def forSchema(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index acfbbace608ef..d597bf7ce756a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -60,38 +57,17 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ @@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.asInstanceOf[UTF8String]) + } + } +} + /** * A function that converts the characters of a string to uppercase. 
*/ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def toString: String = s"Upper($child)" } @@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def toString: String = s"Lower($child)" } @@ -162,15 +159,16 @@ trait StringComparison { override def nullable: Boolean = left.nullable || right.nullable - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } @@ -184,7 +182,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.contains(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** @@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.startsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** @@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.endsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. 
If a start index i is less than 0, it refers @@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case other => slice(other.toString, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 93e69d409cb91..7c80634d2c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal.create(utf, StringType)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 579a0fb8d3f93..ae4620a4e5abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.trees -import 
org.apache.spark.sql.types.{ArrayType, StructType, StructField} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -111,10 +110,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolveChildren( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver, throwErrors) + resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this @@ -122,10 +121,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * `[scope].AttributeName.[nested].[fields]...`. */ def resolve( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, output, resolver, throwErrors) + resolve(nameParts, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsTableColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { assert(nameParts.length > 1) @@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { if (resolver(attribute.name, nameParts.head)) { @@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( - name: String, + nameParts: Seq[String], input: Seq[Attribute], resolver: Resolver, throwErrors: Boolean): Option[NamedExpression] = { - val parts = name.split("\\.") - // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list // of parts that are to be resolved. @@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // and the second element will be List("c"). var candidates: Seq[(Attribute, List[String])] = { // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (parts.length > 1) { + if (nameParts.length > 1) { input.flatMap { option => - resolveAsTableColumn(parts, resolver, option) + resolveAsTableColumn(nameParts, resolver, option) } } else { Seq.empty @@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // If none of attributes match `table.column` pattern, we try to resolve it as a column. if (candidates.isEmpty) { candidates = input.flatMap { candidate => - resolveAsColumn(parts, resolver, candidate) + resolveAsColumn(nameParts, resolver, candidate) } } + def name = UnresolvedAttribute(nameParts).name + candidates.distinct match { // One match, no nested fields, use it. 
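`resolve()` and `resolveChildren()` now receive the name parts directly; the `resolver` they thread through is still just a two-argument predicate that compares individual parts. A hedged sketch of the matching order, with an illustrative case-insensitive resolver (not the one defined in the codebase):

```scala
// Illustrative resolver: compares one name part against one attribute or qualifier name.
val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

resolver("col", "COL")   // true

// For nameParts = Seq("t", "col"):
//   1. resolveAsTableColumn checks the head "t" against the attribute's qualifier and,
//      if it matches, resolves the remaining parts Seq("col") against the attribute name;
//   2. only when no attribute matches that way is resolveAsColumn tried, comparing the
//      head part against attribute names directly.
```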
case Seq((a, Nil)) => Some(a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..fb4217a44807b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. */ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 504fb05842505..d36a49159b87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -40,6 +40,7 @@ object DateUtils { millisToDays(d.getTime) } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala new file mode 100644 index 0000000000000..fc02ba6c9c43e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -0,0 +1,214 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import java.util.Arrays + +/** + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. + */ + +final class UTF8String extends Ordered[UTF8String] with Serializable { + + private[this] var bytes: Array[Byte] = _ + + /** + * Update the UTF8String with String. + */ + def set(str: String): UTF8String = { + bytes = str.getBytes("utf-8") + this + } + + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ + def set(bytes: Array[Byte]): UTF8String = { + this.bytes = bytes + this + } + + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 + } + + /** + * Return the number of code points in it. + * + * This is only used by Substring() when `start` is negative. + */ + def length(): Int = { + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + i += numOfBytes(bytes(i)) + len += 1 + } + len + } + + def getBytes: Array[Byte] = { + bytes + } + + /** + * Return a substring of this, + * @param start the position of first code point + * @param until the position after last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (until <= start || start >= bytes.length || bytes == null) { + new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + i += numOfBytes(bytes(i)) + c += 1 + } + var j = i + while (c < until && j < bytes.length) { + j += numOfBytes(bytes(j)) + c += 1 + } + UTF8String(Arrays.copyOfRange(bytes, i, j)) + } + + def contains(sub: UTF8String): Boolean = { + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In worst case, it's O(N*K), but should works fine with SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false + } + + def startsWith(prefix: UTF8String): Boolean = { + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) + } + + def endsWith(suffix: UTF8String): Boolean = { + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) + } + + def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. + UTF8String(toString().toUpperCase) + } + + def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. 
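The new `UTF8String` keeps the raw UTF-8 bytes but exposes code-point semantics for `length()` and `slice()`; these examples mirror the behaviour exercised by the UTF8StringSuite added later in this patch:

```scala
import org.apache.spark.sql.types.UTF8String

val s = UTF8String("大千世界")

s.length()                        // 4 code points (12 bytes of UTF-8)
s.slice(1, 3)                     // UTF8String("千世"): positions are code points, not bytes
s.contains(UTF8String("千世"))     // true
s.startsWith(UTF8String("大千"))   // true
s.endsWith(UTF8String("世界"))     // true
```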
+ UTF8String(toString().toLowerCase) + } + + override def toString(): String = { + new String(bytes, "utf-8") + } + + override def clone(): UTF8String = new UTF8String().set(this.bytes) + + override def compare(other: UTF8String): Int = { + var i: Int = 0 + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) + if (res != 0) return res + i += 1 + } + bytes.length - b.length + } + + override def compareTo(other: UTF8String): Int = { + compare(other) + } + + override def equals(other: Any): Boolean = other match { + case s: UTF8String => + Arrays.equals(bytes, s.getBytes) + case s: String => + // This is only used for Catalyst unit tests + // fail fast + bytes.length >= s.length && length() == s.length && toString() == s + case _ => + false + } + + override def hashCode(): Int = { + Arrays.hashCode(bytes) + } +} + +object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) + + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index cdf2bc68d9c5e..c6fb22c26bd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = String + private[sql] type JvmType = UTF8String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum * - * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - * where we need to convert Any to UserType. + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. 
*/ def serialize(obj: Any): Any diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6e3d6b9263e86..e10ddfdf5127c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import scala.collection.immutable - class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index d4362a91d992c..76298f03c94ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite { class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { + def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) @@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, - new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, 
create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { @@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - )) + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil @@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 
'a.int.at(1) val c3 = 'a.int.at(2) @@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) @@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) @@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) @@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // substring(null, _, _) -> null checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, new GenericRow(Array[Any](null))) + null, create_row(null)) // substring(_, null, _) -> null checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), @@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebcd..bcc0c404d2cfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new 
GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 0000000000000..a22aa6f244c48 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +// scalastyle:off +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + + test("prefix") { + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + + test("suffix") { + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + + test("slice") { + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index ca4a127120b37..18584c2dcf797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) 
require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData(dataIndex).cachedRepresentation.uncache(blocking) cachedData.remove(dataIndex) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 94ae2d65fd0e4..3235f85d5bbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -158,7 +158,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -166,7 +166,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ee641bdfeb2d7..5c65f04ee8497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -47,6 +47,7 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -128,6 +129,13 @@ private[sql] class SQLConf extends Serializable { /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c25ef58e6f62a..f9f3eb2e03817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,8 +873,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * passed to this function. * * @param columnName the name of a column of integral type that will be used for partitioning. 
- * @param lowerBound the minimum value of `columnName` to retrieve - * @param upperBound the maximum value of `columnName` to retrieve + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil } protected[sql] def openSession(): SQLSession = { @@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c881747751520..00ed70430b84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -153,6 +153,7 @@ private[sql] object ColumnBuilder { val builder: ColumnBuilder = dataType match { case IntegerType => new IntColumnBuilder case LongType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 87a6631da8300..b0f983c180673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats { } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c47497e0662d9..1b9e0df2dcb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: 
ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer): String = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eee0c86d6a1c..d9b6fb43ab83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation( child: SparkPlan, tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null) + private var _statistics: Statistics = null, + private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats = - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + if (_batchStats == null) { + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + } else { + _batchStats + } val partitionStatistics = new PartitionStatistics(output) @@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated) + _cachedColumnBuffers, statisticsToBePropagated, batchStats) } override def children: Seq[LogicalPlan] = Seq.empty @@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - statisticsToBePropagated).asInstanceOf[this.type] + statisticsToBePropagated, + batchStats).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers 
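Reviewer note: together with the CacheManager change earlier in this diff, uncaching now goes through `InMemoryRelation.uncache`, which unpersists the buffers and removes the batch-stats accumulator. A hypothetical session showing the path exercised (the table name `t` and source table `src` are made up):

```scala
// Sketch, assuming an existing SQLContext named `sqlContext` and a registered table `src`.
sqlContext.sql("SELECT * FROM src").registerTempTable("t")
sqlContext.cacheTable("t")
sqlContext.sql("SELECT COUNT(*) FROM t").collect()  // materializes the cached column buffers
sqlContext.uncacheTable("t")                        // unpersists buffers and drops the batch-stats accumulator
```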
override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated) + Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) + + private[sql] def uncache(blocking: Boolean): Unit = { + Accumulators.remove(batchStats.id) + cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } } private[sql] case class InMemoryColumnarTableScan( @@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan( } } + lazy val enableAccumulators: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + // Accumulators used for testing purposes - val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning override def execute(): RDD[Row] = { - readPartitions.setValue(0) - readBatches.setValue(0) + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( @@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan( } } - if (rows.hasNext) { + if (rows.hasNext && enableAccumulators) { readPartitions += 1 } @@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan( logInfo(s"Skipping partition based on stats $statsString") false } else { - readBatches += 1 + if (enableAccumulators) { + readBatches += 1 + } true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..69a620e1ec929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,24 +19,42 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair +object Exchange { + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. + */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. 
*/ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + newOrdering: Seq[SortOrder], + child: SparkPlan) + extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning + override def outputOrdering: Seq[SortOrder] = newOrdering + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ @@ -45,7 +63,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + override def execute(): RDD[Row] = attachTree(this , "execute") { + lazy val sparkConf = child.sqlContext.sparkContext.getConf + newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -56,7 +90,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -69,12 +105,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => @@ -87,9 +128,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) 
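Reviewer note: the `setKeyOrdering` branches above let Exchange sort rows while they are shuffled instead of planning a separate Sort afterwards. The same mechanism is exposed in the plain RDD API; a small illustration with made-up data (not the SQL code path):

```scala
import org.apache.spark.HashPartitioner

// Assumes an existing SparkContext named `sc`.
val pairs = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b"), (1, "z")))
val shuffledSorted = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))
shuffledSorted.glom().collect()  // rows in each partition arrive already sorted by key
```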
shuffled.map(_._1) case SinglePartition => @@ -107,7 +152,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -120,27 +165,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. */ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. def meetsRequirements: Boolean = - !operator.requiredChildDistribution.zip(operator.children).map { + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child.outputOrdering + } + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -157,28 +209,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. 
- def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( + partitioning: Partitioning, + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) + } else { + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } + + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } + } else { + withShuffle + } + + withSort + } + } - if (meetsRequirements && compatible) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. - val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 656bdd7212f56..1fd387eec7e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} import 
org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -54,6 +54,33 @@ object RDDConversions { } } } + + /** + * Convert the objects inside Row into the types Catalyst expected. + */ + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) + val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = converters(i)(r(i)) + i += 1 + } + + mutableRow + } + } + } + } } /** Logical plan node for scanning data from an RDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fabcf6b4a0570..e159ffe66cb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 914f387dec78f..eea15aff5dbcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co private[execution] class KryoResourcePool(size: Int) extends ResourcePool[SerializerInstance](size) { - val ser: KryoSerializer = { + val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. 
- new KryoSerializer(sparkConf) + new SparkSqlSerializer(sparkConf) } def newInstance(): SerializerInstance = ser.newInstance() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5b99e40c2f491..e687d01f57520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f8221f41bc6c3..d286fe81bee5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val resuableProjection = buildProjection() iter.map(resuableProjection) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -117,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan) } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } } @@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
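Reviewer note: with the strategy case above, the new join is opt-in via the flag added in SQLConf. A hypothetical way to try it out (`t1` and `t2` are made-up registered tables):

```scala
// Sketch, assuming an existing SQLContext named `sqlContext`.
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")
val joined = sqlContext.sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key")
joined.explain()  // the physical plan should now contain SortMergeJoin
```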
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -172,6 +178,8 @@ case class Sort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -202,6 +210,8 @@ case class ExternalSort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index fad7a281dc1e2..99f24910fd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e916e68e58b5d..710787096e6cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,7 +164,7 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => - case (_: String, StringType) => + case (_: UTF8String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 0000000000000..b5123668ba11e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. + */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + override def execute(): RDD[Row] = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. 
+ private[this] val joinRow = new JoinedRow5
+ private[this] var leftElement: Row = _
+ private[this] var rightElement: Row = _
+ private[this] var leftKey: Row = _
+ private[this] var rightKey: Row = _
+ private[this] var rightMatches: CompactBuffer[Row] = _
+ private[this] var rightPosition: Int = -1
+ private[this] var stop: Boolean = false
+ private[this] var matchKey: Row = _
+
+ // initialize the iterators
+ initialize()
+
+ override final def hasNext: Boolean = nextMatchingPair()
+
+ override final def next(): Row = {
+ if (hasNext) {
+ // use the buffered right rows and run down the left iterator
+ val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
+ rightPosition += 1
+ if (rightPosition >= rightMatches.size) {
+ rightPosition = 0
+ fetchLeft()
+ if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
+ stop = false
+ rightMatches = null
+ }
+ }
+ joinedRow
+ } else {
+ // no more results
+ throw new NoSuchElementException
+ }
+ }
+
+ private def fetchLeft() = {
+ if (leftIter.hasNext) {
+ leftElement = leftIter.next()
+ leftKey = leftKeyGenerator(leftElement)
+ } else {
+ leftElement = null
+ }
+ }
+
+ private def fetchRight() = {
+ if (rightIter.hasNext) {
+ rightElement = rightIter.next()
+ rightKey = rightKeyGenerator(rightElement)
+ } else {
+ rightElement = null
+ }
+ }
+
+ private def initialize() = {
+ fetchLeft()
+ fetchRight()
+ }
+
+ /**
+ * Searches the right iterator for the next rows that have matches on the left side, and stores
+ * them in a buffer.
+ *
+ * @return true if the search is successful, and false if the right iterator runs out of
+ * tuples.
+ */
+ private def nextMatchingPair(): Boolean = {
+ if (!stop && rightElement != null) {
+ // run both sides to get the first matching pair
+ while (!stop && leftElement != null && rightElement != null) {
+ val comparing = keyOrdering.compare(leftKey, rightKey)
+ // for inner join, we need to filter out rows with null keys
+ stop = comparing == 0 && !leftKey.anyNull
+ if (comparing > 0 || rightKey.anyNull) {
+ fetchRight()
+ } else if (comparing < 0 || leftKey.anyNull) {
+ fetchLeft()
+ }
+ }
+ rightMatches = new CompactBuffer[Row]()
+ if (stop) {
+ stop = false
+ // iterate over the right side to buffer all rows that match
+ // as the records are ordered, exit when we meet the first one that does not match
+ while (!stop && rightElement != null) {
+ rightMatches += rightElement
+ fetchRight()
+ stop = keyOrdering.compare(leftKey, rightKey) != 0
+ }
+ if (rightMatches.size > 0) {
+ rightPosition = 0
+ matchKey = leftKey
+ }
+ }
+ }
+ rightMatches != null && rightMatches.size > 0
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 5b308d88d4cdf..7a43bfd8bc8d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -140,6 +140,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
case (date: Int, DateType) => DateUtils.toJavaDate(date)
+ case (s: UTF8String, StringType) => s.toString
// Pyrolite can handle Timestamp and Decimal
case (other, _) => other
@@ -192,7 +193,8 @@ object EvaluatePython {
case (c: Long, IntegerType) => c.toInt
case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
- case (c, StringType) if !c.isInstanceOf[String] => c.toString
+ case (c: String, StringType) => UTF8String(c)
+ case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 463e1dcc268bc..b9022fcd9e3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -233,7 +233,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } @@ -349,12 +349,14 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + // TODO(davies): convert Date into Int case DateConversion => mutableRow.update(i, rs.getDate(pos)) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7e..99b755c9f25d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -130,6 +130,8 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { + override val needConversion: Boolean = false + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34f864f5fda7a..d4e0abc040bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} +import org.apache.spark.Logging import org.apache.spark.sql.types._ package object jdbc { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f4c99b4b56606..e3352d02787fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row - -import 
org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} private[sql] class DefaultSource @@ -113,6 +113,8 @@ private[sql] case class JSONRelation( // TODO: Support partitioned JSON relation. private def baseRDD = sqlContext.sparkContext.textFile(path) + override val needConversion: Boolean = false + override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b1e8521383756..29de7401dda71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala deleted file mode 100644 index 25a66cb488103..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter} - -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - - override def getWorkPath(): Path = outputPath - override def abortTask(taskContext: TaskAttemptContext): Unit = {} - override def commitTask(taskContext: TaskAttemptContext): Unit = {} - override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true - override def setupJob(jobContext: JobContext): Unit = {} - override def setupTask(taskContext: TaskAttemptContext): Unit = {} - - override def commitJob(jobContext: JobContext) { - try { - val configuration = ContextUtil.getConfiguration(jobContext) - val fileSystem = outputPath.getFileSystem(configuration) - val outputStatus = fileSystem.getFileStatus(outputPath) - val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) - try { - ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { - val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) - fileSystem.create(successPath).close() - } - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } - } - } - } catch { - case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) - } - } - -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 43ca359b51735..bc108e37dfb0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, readTimestamp(value)) @@ 
-475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter( private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - private[this] var dict: Array[String] = null + private[this] var dict: Array[Array[Byte]] = null override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) + parent.updateString(fieldIndex, value.getBytes) } private[parquet] object CatalystArrayConverter { @@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 0357dcc4688be..5eb1c6abc2432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), 
v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 3724bda829d30..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -379,8 +379,6 @@ private[sql] case class InsertIntoParquetTable( */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { - var committer: OutputCommitter = null - // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} @@ -405,26 +403,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } - - // override to create output committer from configuration - override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - if (committer == null) { - val output = getOutputPath(context) - val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] - } - committer - } - - // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 - private def getOutputPath(context: TaskAttemptContext): Path = { - context.getConfiguration().get("mapred.output.dir") match { - case null => null - case name => new Path(name) - } - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 5a1b15490d273..e05a4c20b0d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => 
writer.addInteger(value.asInstanceOf[Int]) @@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 20fdf5e58ef82..af7b3c81ae7b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} - import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.metadata.CompressionCodecName @@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} /** * Allows creation of Parquet based tables using the syntax: @@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2( file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + // Skip type conversion + override val needConversion: Boolean = false + // TODO Should calculate per scan size // It's common that a query only scans a fraction of a large Parquet file. Returning size of the // whole Parquet file disables some optimizations in this case (e.g. broadcast join). 
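Reviewer note: `needConversion = false` above is an internal fast path for the built-in Parquet relation; external data sources keep the default. A minimal hypothetical relation that leaves the default in place and therefore may keep returning plain `java.lang.String` values:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical external data source: needConversion stays true (the default), so Spark SQL
// converts the java.lang.String values it returns into the internal UTF8String representation.
case class WordsRelation(sqlContext: SQLContext) extends BaseRelation with TableScan {
  override val schema = StructType(StructField("word", StringType) :: Nil)
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq("hello", "world")).map(Row(_))
}
```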
@@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2( baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => p.values + case p if split.getPath.getParent.toString == p.path => + CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] }.get val requiredPartOrdinal = partitionKeyLocations.keys.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 34d048e426d10..b3d71f687a60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. 
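Reviewer note: once this selectFilters/translate machinery hands predicates to a data source, string values arrive as plain `java.lang.String` again (converted back from the internal UTF8String literals, as the next hunk shows). A sketch of how a hypothetical PrunedFilteredScan implementation might consume them; `compileFilter` and the SQL fragments are illustrative only:

```scala
import org.apache.spark.sql.sources._

// Hypothetical consumer of the translated filters inside a data source implementation.
def compileFilter(f: Filter): Option[String] = f match {
  case EqualTo(attr, value)           => Some(s"$attr = '$value'")
  case StringStartsWith(attr, prefix) => Some(s"$attr LIKE '$prefix%'")
  case StringContains(attr, part)     => Some(s"$attr LIKE '%$part%'")
  case _                              => None  // not pushed down; Spark evaluates it after the scan
}
```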
@@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) - case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringStartsWith(a.name, v)) + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) - case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringEndsWith(a.name, v)) + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) - case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringContains(a.name, v)) + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 319de710fbc3e..2e861b84b7133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import scala.util.matching.Regex import scala.language.implicitConversions import org.apache.spark.Logging @@ -155,7 +156,19 @@ private[sql] class DDLParser( protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8f9946a5a801e..ca53dcdb92c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -126,6 +126,16 @@ abstract class BaseRelation { * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. 
+ */ + def needConversion: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index f7b5f08beb92f..01e3b8671071e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ @@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest { sql("Clear CACHE") assert(cacheManager.isEmpty) } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + val accsSize = Accumulators.originals.size + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + uncacheTable("t1") + uncacheTable("t2") + + assert(accsSize >= Accumulators.originals.size) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e4dee87849fd4..037d392c1f929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } } test("broadcasted hash join operator selection") { cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + 
classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9a81fc5d72819..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 36465cc2fa11a..bf6cf1321a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 73fb791c3ead7..d739e550f3e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.execution.GeneratedAggregate -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. 
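Aside, not part of the patch: the QueryTest change above converts Array[Byte] answers to Seq before comparing. The reason, in a self-contained illustration — Scala arrays compare by reference, Seqs compare element-wise:

val a = "abc".getBytes("utf-8")
val b = "abc".getBytes("utf-8")
assert(a != b)             // two distinct arrays are never "equal", even with identical contents
assert(a.toSeq == b.toSeq) // Seq equality is element-wise, which is what checkAnswer needs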
@@ -431,6 +426,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } + test("Allow only a single WITH clause per query") { + intercept[RuntimeException] { + sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + } + } + test("date row") { checkAnswer(sql( """select cast("2015-01-28" as date) from testData limit 1"""), @@ -1123,7 +1124,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1223,4 +1224,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } + + test("SPARK-6898: complete support for special chars in column names") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 5f08834f73c6b..b48bed1871c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import java.sql.Timestamp +import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.serializer.KryoRegistrator import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -65,7 +68,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) @@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType[BooleanType.type]( @@ -108,8 +111,8 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { - val bytes = string.getBytes("utf-8") + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -117,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) 
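Aside, not part of the patch: after this change the columnar STRING type stores a 4-byte length prefix followed by the raw UTF-8 bytes, which is why checkActualSize above expects 4 + "hello".getBytes("utf-8").length. A standalone sketch of that layout, mirroring the putter/getter in the test:

import java.nio.ByteBuffer
import org.apache.spark.sql.types.UTF8String

val value = UTF8String("hello")
val bytes = value.getBytes
val buffer = ByteBuffer.allocate(4 + bytes.length) // length prefix + payload
buffer.putInt(bytes.length).put(bytes)

buffer.rewind()
val read = new Array[Byte](buffer.getInt())
buffer.get(read)
assert(UTF8String(read).toString == "hello") // round-trips without ever touching java.lang.String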
testColumnType[BinaryType.type, Array[Byte]]( @@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging { } } + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() + + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + } + } + def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, @@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging { } } } + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a,b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index b301818a008e7..f76314b9dab5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{Decimal, DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -48,7 +48,7 @@ object ColumnarTestUtils { case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => Random.nextString(Random.nextInt(32)) + case STRING => UTF8String(Random.nextString(Random.nextInt(32))) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 479210d1c9c43..56591d9dba29e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.columnar +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{DecimalType, Decimal} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -132,4 +134,59 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, (1 to i).toSeq)) + } + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. 
+ checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index e57bb06e7263b..2a0b701cad7fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + // Enable in-memory table scan accumulators + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 4d0bf7cf99cdf..97c0f439acf13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,27 +381,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } - - test("SPARK-6352 DirectParquetOutputCommitter") { - // Write to a parquet file and let it fail. - // _temporary should be missing if direct output committer works. - try { - configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) - } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) - assert(!fs.exists(path)) - } - } - finally { - configuration.unset("spark.sql.parquet.output.committer.class") - } - } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 3f24a497390c1..ca25751b9583d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -25,17 +25,17 @@ class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) } } -case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, - new MetadataBuilder().putString("comment", "test comment").build()), + new MetadataBuilder().putString("comment", s"test comment $table").build()), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), 
StructField("timestampType", TimestampType, nullable = false), @@ -73,7 +73,8 @@ class DDLTestSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | Table 'test1' |) """.stripMargin) } @@ -81,7 +82,7 @@ class DDLTestSuite extends DataSourceTest { sqlTest( "describe ddlPeople", Seq( - Row("intType", "int", "test comment"), + Row("intType", "int", "test comment test1"), Row("stringType", "string", ""), Row("dateType", "date", ""), Row("timestampType", "timestamp", ""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 60c8c00bda4d5..3b47b8adf313b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -74,7 +74,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -82,7 +82,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index a96b1ffc26966..f38c796241df1 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -44,7 +44,6 @@ com.google.guava guava - runtime ${hive.group} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 62c061bef690a..85281c6d73a3b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -145,6 +145,9 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + // use the specified database if specified + cli.processSelectDatabase(sessionState); + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala 
index 6d1d7c3a4e698..b070fa8eaa469 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,22 +25,31 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { +class CliSuite extends FunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + } + def runCliWithin( timeout: FiniteDuration, extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = Utils.createTempDir() - warehousePath.delete() - val metastorePath = Utils.createTempDir() - metastorePath.delete() val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { @@ -95,8 +104,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin, cause) throw cause } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -124,4 +131,24 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { test("Single command with -e") { runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } + + test("Single command with --database") { + runCliWithin(1.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 0000000000000..65d070bd3cbde --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join is true. + */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 921c6194c7b76..74ae984f34866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. 
The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,9 +239,10 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -284,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -340,7 +344,9 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) @@ -409,7 +415,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. 
case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53a204b8c2932..fd305eb480e63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) + UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) case other => UnresolvedGetField(other, attr) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 1ccb0c279c60e..a6f4fbe8aba06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,24 +17,21 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -131,7 +128,7 @@ private[hive] trait HiveStrategies { val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8efed7f0299bf..cab0fdd35723a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,8 +17,7 @@ 
package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} -import java.io.{DataInputStream, DataOutputStream, EOFException} +import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} import java.util.Properties import scala.collection.JavaConversions._ @@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -121,14 +121,13 @@ case class ScriptTransformation( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 902a12785e3e9..a40a1e53117cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** * Analyzes the given table in the current database to generate statistics, which will be @@ -76,6 +76,12 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 0ed93c2c5b1fa..33e96eaabfbf6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,7 +41,7 @@ 
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} private[hive] case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { @@ -135,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7577309900209..d331c210e8939 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,37 +17,35 @@ package org.apache.spark.sql.hive -import java.util -import java.util.{ArrayList => JArrayList} -import java.util.Properties import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, 
DecimalType} - +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** * This class provides the UDF creation and also the UDF instance serialization and @@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], classOf[Class[_]]) method.setAccessible(true) @@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method @@ -224,7 +218,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala new file mode 100644 index 0000000000000..df1c0a10704c3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.ui + +import scala.xml.Node + +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.ui.UIUtils + +private[ui] abstract class BatchTableBase(tableId: String) { + + protected def columns: Seq[Node] = { + Batch Time + Input Size + Scheduling Delay + Processing Time + } + + protected def baseRow(batch: BatchInfo): Seq[Node] = { + val batchTime = batch.batchTime.milliseconds + val formattedBatchTime = UIUtils.formatDate(batch.batchTime.milliseconds) + val eventCount = batch.receivedBlockInfo.values.map { + receivers => receivers.map(_.numRecords).sum + }.sum + val schedulingDelay = batch.schedulingDelay + val formattedSchedulingDelay = schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") + val processingTime = batch.processingDelay + val formattedProcessingTime = processingTime.map(UIUtils.formatDuration).getOrElse("-") + + {formattedBatchTime} + {eventCount.toString} events + + {formattedSchedulingDelay} + + + {formattedProcessingTime} + + } + + private def batchTable: Seq[Node] = { + + + {columns} + + + {renderRows} + +
    + } + + def toNodeSeq: Seq[Node] = { + batchTable + } + + /** + * Return HTML for all rows of this table. + */ + protected def renderRows: Seq[Node] +} + +private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatches: Seq[BatchInfo]) + extends BatchTableBase("active-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Status + + override protected def renderRows: Seq[Node] = { + // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display + // waiting batches before running batches + waitingBatches.flatMap(batch => {waitingBatchRow(batch)}) ++ + runningBatches.flatMap(batch => {runningBatchRow(batch)}) + } + + private def runningBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ processing + } + + private def waitingBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ queued + } +} + +private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) + extends BatchTableBase("completed-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Total Delay + + override protected def renderRows: Seq[Node] = { + batches.flatMap(batch => {completedBatchRow(batch)}) + } + + private def completedBatchRow(batch: BatchInfo): Seq[Node] = { + val totalDelay = batch.totalDelay + val formattedTotalDelay = totalDelay.map(UIUtils.formatDuration).getOrElse("-") + baseRow(batch) ++ + + {formattedTotalDelay} + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index b6dcb62bfeec8..07fa285642eec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -41,7 +41,8 @@ private[ui] class StreamingPage(parent: StreamingTab) generateBasicStats() ++

      <br></br> ++
      <h4>Statistics over last {listener.retainedCompletedBatches.size} processed batches</h4> ++
      generateReceiverStats() ++
-     generateBatchStatsTable()
+     generateBatchStatsTable() ++
+     generateBatchListTables()
    }
    UIUtils.headerSparkPage("Streaming", content, parent, Some(5000))
  }
@@ -49,9 +50,10 @@
   /** Generate basic stats of the streaming program */
   private def generateBasicStats(): Seq[Node] = {
     val timeSinceStart = System.currentTimeMillis() - startTime
+    // scalastyle:off
     <ul>
       <li>
-        <strong>Started at: </strong> {startTime.toString}
+        <strong>Started at: </strong> {UIUtils.formatDate(startTime)}
       </li>
       <li>
         <strong>Time since start: </strong> {formatDurationVerbose(timeSinceStart)}
@@ -63,18 +65,19 @@
         <strong>Batch interval: </strong> {formatDurationVerbose(listener.batchDuration)}
       </li>
       <li>
-        <strong>Processed batches: </strong> {listener.numTotalCompletedBatches}
+        <strong>Completed batches: </strong> {listener.numTotalCompletedBatches}
       </li>
       <li>
-        <strong>Waiting batches: </strong> {listener.numUnprocessedBatches}
+        <strong>Active batches: </strong> {listener.numUnprocessedBatches}
       </li>
       <li>
-        <strong>Received records: </strong> {listener.numTotalReceivedRecords}
+        <strong>Received events: </strong> {listener.numTotalReceivedRecords}
       </li>
       <li>
-        <strong>Processed records: </strong> {listener.numTotalProcessedRecords}
+        <strong>Processed events: </strong> {listener.numTotalProcessedRecords}
       </li>
     </ul>
+    // scalastyle:on
   }

   /** Generate stats of data received by the receivers in the streaming program */
@@ -86,10 +89,10 @@
       "Receiver",
       "Status",
       "Location",
-      "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]",
-      "Minimum rate\n[records/sec]",
-      "Median rate\n[records/sec]",
-      "Maximum rate\n[records/sec]",
+      "Events in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]",
+      "Minimum rate\n[events/sec]",
+      "Median rate\n[events/sec]",
+      "Maximum rate\n[events/sec]",
       "Last Error"
     )
     val dataRows = (0 until listener.numReceivers).map { receiverId =>
@@ -190,5 +193,26 @@
     }
     UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
   }
+
+  private def generateBatchListTables(): Seq[Node] = {
+    val runningBatches = listener.runningBatches.sortBy(_.batchTime.milliseconds).reverse
+    val waitingBatches = listener.waitingBatches.sortBy(_.batchTime.milliseconds).reverse
+    val completedBatches = listener.retainedCompletedBatches.
+      sortBy(_.batchTime.milliseconds).reverse
+
+    val activeBatchesContent = {

+      <h4>Active Batches ({runningBatches.size + waitingBatches.size})</h4> ++
+        new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq
+    }
+
+    val completedBatchesContent = {
+      <h4>
+        Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches})
+      </h4>
    ++ + new CompletedBatchTable(completedBatches).toNodeSeq + } + + activeBatchesContent ++ completedBatchesContent + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 998426ebb82e5..205ddf6dbe9b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -75,6 +75,17 @@ class UISeleniumSuite val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq statisticText should contain("Network receivers:") statisticText should contain("Batch interval:") + + val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq + h4Text should contain("Active Batches (0)") + h4Text should contain("Completed Batches (last 0 out of 0)") + + findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Status") + } + findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Total Delay") + } } ssc.stop(false) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index c06c0105670c0..a18c94d4ab4a8 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -144,7 +144,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } // Enable this once fix SPARK-6700 - ignore("run Python application in yarn-cluster mode") { + test("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py") @@ -282,10 +282,10 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } -private class SaveExecutorInfo extends SparkListener { +private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() - override def onExecutorAdded(executor : SparkListenerExecutorAdded) { + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } } @@ -293,7 +293,6 @@ private class SaveExecutorInfo extends SparkListener { private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 - var listener: SaveExecutorInfo = null def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -306,10 +305,9 @@ private object YarnClusterDriver extends Logging with Matchers { System.exit(1) } - listener = new SaveExecutorInfo val sc = new SparkContext(new SparkConf() + .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) - 
sc.addSparkListener(listener) val status = new File(args(0)) var result = "failure" try { @@ -323,7 +321,12 @@ private object YarnClusterDriver extends Logging with Matchers { } // verify log urls are present - listener.addedExecutorInfos.values.foreach { info => + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) + val executorInfos = listener.addedExecutorInfos.values + assert(executorInfos.nonEmpty) + executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) } }
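Aside, not part of the patch: the YARN test above now registers its listener through the spark.extraListeners configuration instead of sc.addSparkListener, so the listener is attached before any events are posted. A minimal, hypothetical sketch of that registration pattern (the class and app names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}

// The listener must be instantiable by reflection (zero-arg constructor), because
// SparkContext creates it from the class name given in the configuration.
class ExecutorLoggingListener extends SparkListener {
  override def onExecutorAdded(executor: SparkListenerExecutorAdded): Unit = {
    println(s"executor added: ${executor.executorId}")
  }
}

val conf = new SparkConf()
  .setAppName("extra-listeners-example")
  .setMaster("local[2]")
  .set("spark.extraListeners", classOf[ExecutorLoggingListener].getName)
val sc = new SparkContext(conf)
// ... run jobs; the listener receives events from the very first one ...
sc.stop()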