From ca43186867f0120c29d1b27cfee0c7ff4a107d84 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 7 May 2014 16:54:58 -0400 Subject: [PATCH 01/34] [SQL] Fix Performance Issue in data type casting Using lazy val object instead of function in the class Cast, which improved the performance nearly by 2X in my local micro-benchmark. Author: Cheng Hao Closes #679 from chenghao-intel/fix_type_casting and squashes the following commits: 71b0902 [Cheng Hao] using lazy val object instead of function for data type casting --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 40d2b42a0cda3..0b3a4e728ec54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -182,7 +182,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - def cast: Any => Any = dataType match { + private lazy val cast: Any => Any = dataType match { case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal From 7f6f4a1035ae0c9fa2029fe991f621ca263d53e0 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 7 May 2014 17:24:12 -0400 Subject: [PATCH 02/34] Nicer logging for SecurityManager startup Happy to open a jira ticket if you'd like to track one there. Author: Andrew Ash Closes #678 from ash211/SecurityManagerLogging and squashes the following commits: 2aa0b7a [Andrew Ash] Nicer logging for SecurityManager startup --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b4b0067801259..74aa441619bd2 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -146,8 +146,9 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) private val secretKey = generateSecretKey() - logInfo("SecurityManager, is authentication enabled: " + authOn + - " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) + logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + + "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") + + "; users with view permissions: " + viewAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its From d00981a95185229fd1594d5c030a00f219fb1a14 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 7 May 2014 17:24:49 -0400 Subject: [PATCH 03/34] Typo fix: fetchting -> fetching Author: Andrew Ash Closes #680 from ash211/patch-3 and squashes the following commits: 9ce3746 [Andrew Ash] Typo fix: fetchting -> fetching --- make-distribution.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index ebcd8c74fc5a6..759e555b4b69a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -189,7 +189,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` pushd $TMPD > /dev/null - echo "Fetchting tachyon tgz" + echo "Fetching tachyon tgz" wget "$TACHYON_URL" tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" From 5200872243aa5906dc8a06772e61d75f19557aac Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 7 May 2014 14:35:22 -0700 Subject: [PATCH 04/34] [SPARK-1688] Propagate PySpark worker stderr to driver When at least one of the following conditions is true, PySpark cannot be loaded: 1. PYTHONPATH is not set 2. PYTHONPATH does not contain the python directory (or jar, in the case of YARN) 3. The jar does not contain pyspark files (YARN) 4. The jar does not contain py4j files (YARN) However, we currently throw the same random `java.io.EOFException` for all of the above cases, when trying to read from the python daemon's output. This message is super unhelpful. This PR includes the python stderr and the PYTHONPATH in the exception propagated to the driver. Now, the exception message looks something like: ``` Error from python worker: : No module named pyspark PYTHONPATH was: /path/to/spark/python:/path/to/some/jar java.io.EOFException ``` whereas before it was just ``` java.io.EOFException ``` Author: Andrew Or Closes #603 from andrewor14/pyspark-exception and squashes the following commits: 10d65d3 [Andrew Or] Throwable -> Exception, worker -> daemon 862d1d7 [Andrew Or] Merge branch 'master' of github.com:apache/spark into pyspark-exception a5ed798 [Andrew Or] Use block string and interpolation instead of var (minor) cc09c45 [Andrew Or] Account for the fact that the python daemon may not have terminated yet 444f019 [Andrew Or] Use the new RedirectThread + include system PYTHONPATH aab00ae [Andrew Or] Merge branch 'master' of github.com:apache/spark into pyspark-exception 0cc2402 [Andrew Or] Merge branch 'master' of github.com:apache/spark into pyspark-exception 783efe2 [Andrew Or] Make python daemon stderr indentation consistent 9524172 [Andrew Or] Avoid potential NPE / error stream contention + Move things around 29f9688 [Andrew Or] Add back original exception type e92d36b [Andrew Or] Include python worker stderr in the exception propagated to the driver 7c69360 [Andrew Or] Merge branch 'master' of github.com:apache/spark into pyspark-exception cdbc185 [Andrew Or] Fix python attribute not found exception when PYTHONPATH is not set dcc0353 [Andrew Or] Check both python and system environment variables for PYTHONPATH 6c09c21 [Andrew Or] Validate PYTHONPATH and PySpark modules before starting python workers --- .../apache/spark/api/python/PythonUtils.scala | 27 +++- .../api/python/PythonWorkerFactory.scala | 136 ++++++++---------- .../apache/spark/deploy/PythonRunner.scala | 24 +--- .../scala/org/apache/spark/util/Utils.scala | 37 +++++ 4 files changed, 123 insertions(+), 101 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index cf69fa1d53fde..6d3e257c4d5df 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.File +import java.io.{File, InputStream, IOException, OutputStream} import scala.collection.mutable.ArrayBuffer @@ -40,3 +40,28 @@ private[spark] object PythonUtils { paths.filter(_ != "").mkString(File.pathSeparator) } } + + +/** + * A utility class to redirect the child process's stdout or stderr. + */ +private[spark] class RedirectThread( + in: InputStream, + out: OutputStream, + name: String) + extends Thread(name) { + + setDaemon(true) + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME: We copy the stream on the level of bytes to avoid encoding problems. + val buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + out.write(buf, 0, len) + out.flush() + len = in.read(buf) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index b0bf4e052b3e9..002f2acd94dee 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,15 +17,18 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, File, IOException, OutputStreamWriter} +import java.io.{DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import scala.collection.JavaConversions._ import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) - extends Logging { + extends Logging { + + import PythonWorkerFactory._ // Because forking processes from Java is expensive, we prefer to launch a single Python daemon // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently @@ -38,7 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemonPort: Int = 0 val pythonPath = PythonUtils.mergePythonPaths( - PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", "")) + PythonUtils.sparkPythonPath, + envVars.getOrElse("PYTHONPATH", ""), + sys.env.getOrElse("PYTHONPATH", "")) def create(): Socket = { if (useDaemon) { @@ -61,12 +66,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { new Socket(daemonHost, daemonPort) } catch { - case exc: SocketException => { + case exc: SocketException => logWarning("Python daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() new Socket(daemonHost, daemonPort) - } case e: Throwable => throw e } } @@ -87,39 +91,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) val worker = pb.start() - // Redirect the worker's stderr to ours - new Thread("stderr reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = worker.getErrorStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - - // Redirect worker's stdout to our stderr - new Thread("stdout reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = worker.getInputStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() + // Redirect worker stdout and stderr + redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) // Tell the worker our port val out = new OutputStreamWriter(worker.getOutputStream) @@ -142,10 +115,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String null } - def stop() { - stopDaemon() - } - private def startDaemon() { synchronized { // Is it already running? @@ -161,46 +130,38 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) daemon = pb.start() - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val in = daemon.getErrorStream - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - val in = new DataInputStream(daemon.getInputStream) daemonPort = in.readInt() - // Redirect further stdout output to our stderr - new Thread("stdout reader for " + pythonExec) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() + // Redirect daemon stdout and stderr + redirectStreamsToStderr(in, daemon.getErrorStream) + } catch { - case e: Throwable => { + case e: Exception => + + // If the daemon exists, wait for it to finish and get its stderr + val stderr = Option(daemon) + .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) } + .getOrElse("") + stopDaemon() - throw e - } + + if (stderr != "") { + val formattedStderr = stderr.replace("\n", "\n ") + val errorMessage = s""" + |Error from python worker: + | $formattedStderr + |PYTHONPATH was: + | $pythonPath + |$e""" + + // Append error message from python daemon, but keep original stack trace + val wrappedException = new SparkException(errorMessage.stripMargin) + wrappedException.setStackTrace(e.getStackTrace) + throw wrappedException + } else { + throw e + } } // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly @@ -208,6 +169,19 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Redirect the given streams to our stderr in separate threads. + */ + private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream) { + try { + new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start() + new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start() + } catch { + case e: Exception => + logError("Exception in redirecting streams", e) + } + } + private def stopDaemon() { synchronized { // Request shutdown of existing daemon by sending SIGTERM @@ -219,4 +193,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemonPort = 0 } } + + def stop() { + stopDaemon() + } +} + +private object PythonWorkerFactory { + val PROCESS_WAIT_TIMEOUT_MS = 10000 } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index f2e7c7a508b3f..e20d4486c8f0c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -17,13 +17,10 @@ package org.apache.spark.deploy -import java.io.{IOException, File, InputStream, OutputStream} - import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ -import org.apache.spark.SparkContext -import org.apache.spark.api.python.PythonUtils +import org.apache.spark.api.python.{PythonUtils, RedirectThread} /** * A main class used by spark-submit to launch Python applications. It executes python as a @@ -62,23 +59,4 @@ object PythonRunner { System.exit(process.waitFor()) } - - /** - * A utility class to redirect the child process's stdout or stderr - */ - class RedirectThread(in: InputStream, out: OutputStream, name: String) extends Thread(name) { - setDaemon(true) - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME: We copy the stream on the level of bytes to avoid encoding problems. - val buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - out.write(buf, 0, len) - out.flush() - len = in.read(buf) - } - } - } - } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 202bd46956f87..3f0ed61c5bbfb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1088,4 +1088,41 @@ private[spark] object Utils extends Logging { def stripDirectory(path: String): String = { path.split(File.separator).last } + + /** + * Wait for a process to terminate for at most the specified duration. + * Return whether the process actually terminated after the given timeout. + */ + def waitForProcess(process: Process, timeoutMs: Long): Boolean = { + var terminated = false + val startTime = System.currentTimeMillis + while (!terminated) { + try { + process.exitValue + terminated = true + } catch { + case e: IllegalThreadStateException => + // Process not terminated yet + if (System.currentTimeMillis - startTime > timeoutMs) { + return false + } + Thread.sleep(100) + } + } + true + } + + /** + * Return the stderr of a process after waiting for the process to terminate. + * If the process does not terminate within the specified timeout, return None. + */ + def getStderr(process: Process, timeoutMs: Long): Option[String] = { + val terminated = Utils.waitForProcess(process, timeoutMs) + if (terminated) { + Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n")) + } else { + None + } + } + } From 4bec84b6a23e1e642708a70a6c7ef7b3d1df9b3e Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Wed, 7 May 2014 15:51:53 -0700 Subject: [PATCH 05/34] SPARK-1569 Spark on Yarn, authentication broken by pr299 Pass the configs as java options since the executor needs to know before it registers whether to create the connection using authentication or not. We could see about passing only the authentication configs but for now I just had it pass them all. I also updating it to use a list to construct the command to make it the same as ClientBase and avoid any issues with spaces. Author: Thomas Graves Closes #649 from tgravescs/SPARK-1569 and squashes the following commits: 0178ab8 [Thomas Graves] add akka settings 22a8735 [Thomas Graves] Change to only path spark.auth* configs 8ccc1d4 [Thomas Graves] SPARK-1569 Spark on Yarn, authentication broken --- .../deploy/yarn/ExecutorRunnableUtil.scala | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 96f8aa93394f5..32f8861dc9503 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -21,7 +21,7 @@ import java.io.File import java.net.URI import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, ListBuffer} import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api._ @@ -44,9 +44,9 @@ trait ExecutorRunnableUtil extends Logging { hostname: String, executorMemory: Int, executorCores: Int, - localResources: HashMap[String, LocalResource]) = { + localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM - var JAVA_OPTS = "" + val JAVA_OPTS = ListBuffer[String]() // Set the JVM memory val executorMemoryString = executorMemory + "m" JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " @@ -56,10 +56,21 @@ trait ExecutorRunnableUtil extends Logging { JAVA_OPTS += opts } - JAVA_OPTS += " -Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + JAVA_OPTS += "-Djava.io.tmpdir=" + + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources) + // Certain configs need to be passed here because they are needed before the Executor + // registers with the Scheduler and transfers the spark configs. Since the Executor backend + // uses Akka to connect to the scheduler, the akka settings are needed as well as the + // authentication settings. + sparkConf.getAll. + filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. + foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + + sparkConf.getAkkaConf. + foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" } + // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. // The context is, default gc for server class machines end up using all cores to do gc - hence @@ -85,25 +96,25 @@ trait ExecutorRunnableUtil extends Logging { } */ - val commands = List[String]( - Environment.JAVA_HOME.$() + "/bin/java" + - " -server " + + val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", + "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? - " -XX:OnOutOfMemoryError='kill %p' " + - JAVA_OPTS + - " org.apache.spark.executor.CoarseGrainedExecutorBackend " + - masterAddress + " " + - slaveId + " " + - hostname + " " + - executorCores + - " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + - " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - - commands + "-XX:OnOutOfMemoryError='kill %p'") ++ + JAVA_OPTS ++ + Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + masterAddress.toString, + slaveId.toString, + hostname.toString, + executorCores.toString, + "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", + "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + + // TODO: it would be nicer to just make sure there are no null commands here + commands.map(s => if (s == null) "null" else s).toList } private def setupDistributedCache( From 3188553f73970270717a7fee4a116e29ad4becc9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 7 May 2014 16:01:11 -0700 Subject: [PATCH 06/34] [SPARK-1743][MLLIB] add loadLibSVMFile and saveAsLibSVMFile to pyspark Make loading/saving labeled data easier for pyspark users. Also changed type check in `SparseVector` to allow numpy integers. Author: Xiangrui Meng Closes #672 from mengxr/pyspark-mllib-util and squashes the following commits: 2943fa7 [Xiangrui Meng] format docs d61668d [Xiangrui Meng] add loadLibSVMFile and saveAsLibSVMFile to pyspark --- python/pyspark/mllib/linalg.py | 3 +- python/pyspark/mllib/util.py | 177 +++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 python/pyspark/mllib/util.py diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 0aa3a51de706b..7511ca7573ddb 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -49,8 +49,7 @@ def __init__(self, size, *args): >>> print SparseVector(4, [1, 3], [1.0, 5.5]) [1: 1.0, 3: 5.5] """ - assert type(size) == int, "first argument must be an int" - self.size = size + self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" if len(args) == 1: pairs = args[0] diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py new file mode 100644 index 0000000000000..50d0cdd087625 --- /dev/null +++ b/python/pyspark/mllib/util.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np + +from pyspark.mllib.linalg import Vectors, SparseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib._common import _convert_vector + +class MLUtils: + """ + Helper methods to load, save and pre-process data used in MLlib. + """ + + @staticmethod + def _parse_libsvm_line(line, multiclass): + """ + Parses a line in LIBSVM format into (label, indices, values). + """ + items = line.split(None) + label = float(items[0]) + if not multiclass: + label = 1.0 if label > 0.5 else 0.0 + nnz = len(items) - 1 + indices = np.zeros(nnz, dtype=np.int32) + values = np.zeros(nnz) + for i in xrange(nnz): + index, value = items[1 + i].split(":") + indices[i] = int(index) - 1 + values[i] = float(value) + return label, indices, values + + + @staticmethod + def _convert_labeled_point_to_libsvm(p): + """Converts a LabeledPoint to a string in LIBSVM format.""" + items = [str(p.label)] + v = _convert_vector(p.features) + if type(v) == np.ndarray: + for i in xrange(len(v)): + items.append(str(i + 1) + ":" + str(v[i])) + elif type(v) == SparseVector: + nnz = len(v.indices) + for i in xrange(nnz): + items.append(str(v.indices[i] + 1) + ":" + str(v.values[i])) + else: + raise TypeError("_convert_labeled_point_to_libsvm needs either ndarray or SparseVector" + " but got " % type(v)) + return " ".join(items) + + + @staticmethod + def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + """ + Loads labeled data in the LIBSVM format into an RDD of + LabeledPoint. The LIBSVM format is a text-based format used by + LIBSVM and LIBLINEAR. Each line represents a labeled sparse + feature vector using the following format: + + label index1:value1 index2:value2 ... + + where the indices are one-based and in ascending order. This + method parses each line into a LabeledPoint, where the feature + indices are converted to zero-based. + + @param sc: Spark context + @param path: file or directory path in any Hadoop-supported file + system URI + @param multiclass: whether the input labels contain more than + two classes. If false, any label with value + greater than 0.5 will be mapped to 1.0, or + 0.0 otherwise. So it works for both +1/-1 and + 1/0 cases. If true, the double value parsed + directly from the label string will be used + as the label value. + @param numFeatures: number of features, which will be determined + from the input data if a nonpositive value + is given. This is useful when the dataset is + already split into multiple files and you + want to load them separately, because some + features may not present in certain files, + which leads to inconsistent feature + dimensions. + @param minPartitions: min number of partitions + @return: labeled data stored as an RDD of LabeledPoint + + >>> from tempfile import NamedTemporaryFile + >>> from pyspark.mllib.util import MLUtils + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> tempFile.flush() + >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() + >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() + >>> tempFile.close() + >>> examples[0].label + 1.0 + >>> examples[0].features.size + 6 + >>> print examples[0].features + [0: 1.0, 2: 2.0, 4: 3.0] + >>> examples[1].label + 0.0 + >>> examples[1].features.size + 6 + >>> print examples[1].features + [] + >>> examples[2].label + 0.0 + >>> examples[2].features.size + 6 + >>> print examples[2].features + [1: 4.0, 3: 5.0, 5: 6.0] + >>> multiclass_examples[1].label + -1.0 + """ + + lines = sc.textFile(path, minPartitions) + parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l, multiclass)) + if numFeatures <= 0: + parsed.cache() + numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) + + + @staticmethod + def saveAsLibSVMFile(data, dir): + """ + Save labeled data in LIBSVM format. + + @param data: an RDD of LabeledPoint to be saved + @param dir: directory to save the data + + >>> from tempfile import NamedTemporaryFile + >>> from fileinput import input + >>> from glob import glob + >>> from pyspark.mllib.util import MLUtils + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ + LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) + >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*")))) + '0.0 1:1.01 2:2.02 3:3.03\\n1.1 1:1.23 3:4.56\\n' + """ + lines = data.map(lambda p: MLUtils._convert_labeled_point_to_libsvm(p)) + lines.saveAsTextFile(dir) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 0c19bb161b9b2b96c0c55d3ea09e81fd798cbec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?baishuo=28=E7=99=BD=E7=A1=95=29?= Date: Wed, 7 May 2014 16:02:55 -0700 Subject: [PATCH 07/34] Update GradientDescentSuite.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use more faster way to construct an array Author: baishuo(白硕) Closes #588 from baishuo/master and squashes the following commits: 45b95fb [baishuo(白硕)] Update GradientDescentSuite.scala c03b61c [baishuo(白硕)] Update GradientDescentSuite.scala b666d27 [baishuo(白硕)] Update GradientDescentSuite.scala --- .../spark/mllib/optimization/GradientDescentSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index c4b433499a091..8a16284118cf7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -81,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0, features.toArray: _*) + label -> Vectors.dense(1.0 +: features.toArray) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*) + val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -111,7 +111,7 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Vectors.dense(1.0, features.toArray: _*) + label -> Vectors.dense(1.0 +: features.toArray) } val dataRDD = sc.parallelize(data, 2).cache() From f269b016acb17b24d106dc2b32a1be389489bb01 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 7 May 2014 17:08:38 -0700 Subject: [PATCH 08/34] SPARK-1544 Add support for deep decision trees. @etrain and I came with a PR for arbitrarily deep decision trees at the cost of multiple passes over the data at deep tree levels. To summarize: 1) We take a parameter that indicates the amount of memory users want to reserve for computation on each worker (and 2x that at the driver). 2) Using that information, we calculate two things - the maximum depth to which we train as usual (which is, implicitly, the maximum number of nodes we want to train in parallel), and the size of the groups we should use in the case where we exceed this depth. cc: @atalwalkar, @hirakendu, @mengxr Author: Manish Amde Author: manishamde Author: Evan Sparks Closes #475 from manishamde/deep_tree and squashes the following commits: 968ca9d [Manish Amde] merged master 7fc9545 [Manish Amde] added docs ce004a1 [Manish Amde] minor formatting b27ad2c [Manish Amde] formatting 426bb28 [Manish Amde] programming guide blurb 8053fed [Manish Amde] more formatting 5eca9e4 [Manish Amde] grammar 4731cda [Manish Amde] formatting 5e82202 [Manish Amde] added documentation, fixed off by 1 error in max level calculation cbd9f14 [Manish Amde] modified scala.math to math dad9652 [Manish Amde] removed unused imports e0426ee [Manish Amde] renamed parameter 718506b [Manish Amde] added unit test 1517155 [Manish Amde] updated documentation 9dbdabe [Manish Amde] merge from master 719d009 [Manish Amde] updating user documentation fecf89a [manishamde] Merge pull request #6 from etrain/deep_tree 0287772 [Evan Sparks] Fixing scalastyle issue. 2f1e093 [Manish Amde] minor: added doc for maxMemory parameter 2f6072c [manishamde] Merge pull request #5 from etrain/deep_tree abc5a23 [Evan Sparks] Parameterizing max memory. 50b143a [Manish Amde] adding support for very deep trees --- docs/mllib-decision-tree.md | 15 +-- .../examples/mllib/DecisionTreeRunner.scala | 2 +- .../spark/mllib/tree/DecisionTree.scala | 103 ++++++++++++++++-- .../mllib/tree/configuration/Strategy.scala | 6 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 84 ++++++++++++-- 5 files changed, 177 insertions(+), 33 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 296277e58b341..acf0feff42a8d 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -93,17 +93,14 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training parameter 2. No split candidate leads to an information gain at the node. +### Max memory requirements + +For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree. This could lead to high memory requirements at deeper levels of the tree leading to memory overflow errors. To alleviate this problem, a 'maxMemoryInMB' training parameter is provided which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements for a level-wise computation crosses the `maxMemoryInMB` threshold, the node training tasks at each subsequent level is split into smaller tasks. + ### Practical limitations -1. The tree implementation stores an `Array[Double]` of size *O(#features \* #splits \* 2^maxDepth)* - in memory for aggregating histograms over partitions. The current implementation might not scale - to very deep trees since the memory requirement grows exponentially with tree depth. -2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for - sparse input. -3. Python is not supported in this release. - -We are planning to solve these problems in the near future. Please drop us a line if you encounter -any issues. +1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. +2. Python is not supported in this release. ## Examples diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0bd847d7bab30..9832bec90d7ee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -51,7 +51,7 @@ object DecisionTreeRunner { algo: Algo = Classification, maxDepth: Int = 5, impurity: ImpurityType = Gini, - maxBins: Int = 20) + maxBins: Int = 100) def main(args: Array[String]) { val defaultParams = Params() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 59ed01debf150..0fe30a3e7040b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -54,12 +54,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - logDebug("numSplits = " + bins(0).length) + val numBins = bins(0).length + logDebug("numBins = " + numBins) // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -68,7 +69,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) + // num features + val numFeatures = input.take(1)(0).features.size + + // Calculate level for single group construction + // Max memory usage for aggregates + val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val numElementsPerNode = + strategy.algo match { + case Classification => 2 * numBins * numFeatures + case Regression => 3 * numBins * numFeatures + } + + logDebug("numElementsPerNode = " + numElementsPerNode) + val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array + val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) + logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) + // nodes at a level is 2^level. level is zero indexed. + val maxLevelForSingleGroup = math.max( + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) + logDebug("max level for single group = " + maxLevelForSingleGroup) /* * The main idea here is to perform level-wise training of the decision tree nodes thus @@ -88,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins) + level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -98,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) + require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -109,6 +131,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + logDebug("#####################################") + logDebug("Extracting tree model") + logDebug("#####################################") + // Initialize the top or root node of the tree. val topNode = nodes(0) // Build the full tree using the node info calculated in the level-wise best split calculations. @@ -127,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val nodeIndex = math.pow(2, level).toInt - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -148,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity @@ -249,7 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level + * Returns an array of optimal splits for all nodes at a given level. Splits the task into + * multiple groups if the level-wise training task could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree @@ -260,6 +287,7 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( @@ -269,7 +297,57 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]], + maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation + if (level > maxLevelForSingleGroup) { + // When information for all nodes at a given level cannot be stored in memory, + // the nodes are divided into multiple groups at each level with the number of groups + // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, + // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. + val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + logDebug("numGroups = " + numGroups) + var bestSplits = new Array[(Split, InformationGainStats)](0) + // Iterate over each group of nodes at a level. + var groupIndex = 0 + while (groupIndex < numGroups) { + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + filters, splits, bins, numGroups, groupIndex) + bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + groupIndex += 1 + } + bestSplits + } else { + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + } + } + + /** + * Returns an array of optimal splits for a group of nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @param numGroups total number of node groups at the current level. Default value is set to 1. + * @param groupIndex index of the node group being processed. Default value is set to 0. + * @return array of splits with best splits for all nodes at a given level. + */ + private def findBestSplitsPerGroup( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]], + numGroups: Int = 1, + groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* * The high-level description for the best split optimizations are noted here. @@ -296,7 +374,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt + val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size @@ -304,12 +382,15 @@ object DecisionTree extends Serializable with Logging { val numBins = bins(0).length logDebug("numBins = " + numBins) + // shift when more than one group is used at deep tree level + val groupShift = numNodes * groupIndex + /** Find the filters used before reaching the current code. */ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -878,7 +959,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8767aca47cd5a..1b505fd76eb75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * */ @Experimental class Strategy ( @@ -43,4 +46,5 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val maxMemoryInMB: Int = 128) extends Serializable diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index be383aab714d3..35e92d71dc63f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -242,7 +243,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -269,7 +270,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -298,7 +299,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -321,7 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -345,7 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -369,7 +370,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -378,13 +379,60 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) assert(bestSplits(0)._2.predict === 1) } + + test("test second level node building with/without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) + val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + val parentImpurities = Array(0.5, 0.5, 0.5) + + // Single group second level tree construction. + val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + splits, bins, 10) + assert(bestSplits.length === 2) + assert(bestSplits(0)._2.gain > 0) + assert(bestSplits(1)._2.gain > 0) + + // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second + // level tree construction. + val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + filters, splits, bins, 0) + assert(bestSplitsWithGroups.length === 2) + assert(bestSplitsWithGroups(0)._2.gain > 0) + assert(bestSplitsWithGroups(1)._2.gain > 0) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until bestSplits.length) { + assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) + assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) + assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) + assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) + assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) + assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + } + + } + } object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -393,17 +441,31 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr } + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000) { + if (i < 600) { + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } else { + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + } + arr + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ - if (i < 600){ + for (i <- 0 until 1000) { + if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) From 108c4c16cc82af2e161d569d2c23849bdbf4aadb Mon Sep 17 00:00:00 2001 From: Sandeep Date: Thu, 8 May 2014 00:15:05 -0400 Subject: [PATCH 09/34] SPARK-1668: Add implicit preference as an option to examples/MovieLensALS Add --implicitPrefs as an command-line option to the example app MovieLensALS under examples/ Author: Sandeep Closes #597 from techaddict/SPARK-1668 and squashes the following commits: 8b371dc [Sandeep] Second Pass on reviews by mengxr eca9d37 [Sandeep] based on mengxr's suggestions 937e54c [Sandeep] Changes 5149d40 [Sandeep] Changes based on review 1dd7657 [Sandeep] use mean() 42444d7 [Sandeep] Based on Suggestions by mengxr e3082fa [Sandeep] SPARK-1668: Add implicit preference as an option to examples/MovieLensALS Add --implicitPrefs as an command-line option to the example app MovieLensALS under examples/ --- .../spark/examples/mllib/MovieLensALS.scala | 55 ++++++++++++++++--- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 703f02255b94b..0e4447e0de24f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -43,7 +43,8 @@ object MovieLensALS { kryo: Boolean = false, numIterations: Int = 20, lambda: Double = 1.0, - rank: Int = 10) + rank: Int = 10, + implicitPrefs: Boolean = false) def main(args: Array[String]) { val defaultParams = Params() @@ -62,6 +63,9 @@ object MovieLensALS { opt[Unit]("kryo") .text(s"use Kryo serialization") .action((_, c) => c.copy(kryo = true)) + opt[Unit]("implicitPrefs") + .text("use implicit preference") + .action((_, c) => c.copy(implicitPrefs = true)) arg[String]("") .required() .text("input paths to a MovieLens dataset of ratings") @@ -88,7 +92,25 @@ object MovieLensALS { val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + if (params.implicitPrefs) { + /* + * MovieLens ratings are on a scale of 1-5: + * 5: Must see + * 4: Will enjoy + * 3: It's okay + * 2: Fairly bad + * 1: Awful + * So we should not recommend a movie if the predicted rating is less than 3. + * To map ratings to confidence scores, we use + * 5 -> 2.5, 4 -> 1.5, 3 -> 0.5, 2 -> -0.5, 1 -> -1.5. This mappings means unobserved + * entries are generally between It's okay and Fairly bad. + * The semantics of 0 in this expanded world of non-positive weights + * are "the same as never having interacted at all". + */ + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + } else { + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + } }.cache() val numRatings = ratings.count() @@ -99,7 +121,18 @@ object MovieLensALS { val splits = ratings.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() - val test = splits(1).cache() + val test = if (params.implicitPrefs) { + /* + * 0 means "don't know" and positive values mean "confident that the prediction should be 1". + * Negative values means "confident that the prediction should be 0". + * We have in this case used some kind of weighted RMSE. The weight is the absolute value of + * the confidence. The error is the difference between prediction and either 1 or 0, + * depending on whether r is positive or negative. + */ + splits(1).map(x => Rating(x.user, x.product, if (x.rating > 0) 1.0 else 0.0)) + } else { + splits(1) + }.cache() val numTraining = training.count() val numTest = test.count() @@ -111,9 +144,10 @@ object MovieLensALS { .setRank(params.rank) .setIterations(params.numIterations) .setLambda(params.lambda) + .setImplicitPrefs(params.implicitPrefs) .run(training) - val rmse = computeRmse(model, test, numTest) + val rmse = computeRmse(model, test, params.implicitPrefs) println(s"Test RMSE = $rmse.") @@ -121,11 +155,14 @@ object MovieLensALS { } /** Compute RMSE (Root Mean Squared Error). */ - def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], n: Long) = { + def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = { + + def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r + val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product))) - val predictionsAndRatings = predictions.map(x => ((x.user, x.product), x.rating)) - .join(data.map(x => ((x.user, x.product), x.rating))) - .values - math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).reduce(_ + _) / n) + val predictionsAndRatings = predictions.map{ x => + ((x.user, x.product), mapPredictedRating(x.rating)) + }.join(data.map(x => ((x.user, x.product), x.rating))).values + math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean()) } } From 6ed7e2cd01955adfbb3960e2986b6d19eaee8717 Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Thu, 8 May 2014 00:24:36 -0400 Subject: [PATCH 10/34] Use numpy directly for matrix multiply. Using matrix multiply to compute XtX and XtY yields a 5-20x speedup depending on problem size. For example - the following takes 19s locally after this change vs. 5m21s before the change. (16x speedup). bin/pyspark examples/src/main/python/als.py local[8] 1000 1000 50 10 10 Author: Evan Sparks Closes #687 from etrain/patch-1 and squashes the following commits: e094dbc [Evan Sparks] Touching only diaganols on update. d1ab9b6 [Evan Sparks] Use numpy directly for matrix multiply. --- examples/src/main/python/als.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index a77dfb2577835..33700ab4f8c53 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,14 +36,13 @@ def rmse(R, ms, us): def update(i, vec, mat, ratings): uu = mat.shape[0] ff = mat.shape[1] - XtX = matrix(np.zeros((ff, ff))) - Xty = np.zeros((ff, 1)) - - for j in range(uu): - v = mat[j, :] - XtX += v.T * v - Xty += v.T * ratings[i, j] - XtX += np.eye(ff, ff) * LAMBDA * uu + + XtX = mat.T * mat + XtY = mat.T * ratings[i, :].T + + for j in range(ff): + XtX[j,j] += LAMBDA * uu + return np.linalg.solve(XtX, Xty) if __name__ == "__main__": From 19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 8 May 2014 01:08:43 -0400 Subject: [PATCH 11/34] [SQL] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum. Author: Michael Armbrust Closes #683 from marmbrus/aggFixes and squashes the following commits: 64fe30b [Michael Armbrust] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum. --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 + .../sql/catalyst/expressions/aggregates.scala | 85 ++++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 7 ++ .../scala/org/apache/spark/sql/TestData.scala | 10 +++ 4 files changed, 96 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 8c76a3aa96546..b3a3a1ef1b5eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val JOIN = Keyword("JOIN") protected val LEFT = Keyword("LEFT") protected val LIMIT = Keyword("LIMIT") + protected val MAX = Keyword("MAX") + protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") @@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | + MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | + MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index b152f95f96c70..7777d372903e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -86,6 +86,67 @@ abstract class AggregateFunction override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } +case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references = child.references + override def nullable = child.nullable + override def dataType = child.dataType + override def toString = s"MIN($child)" + + override def asPartial: SplitEvaluation = { + val partialMin = Alias(Min(child), "PartialMin")() + SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) + } + + override def newInstance() = new MinFunction(child, this) +} + +case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var currentMin: Any = _ + + override def update(input: Row): Unit = { + if (currentMin == null) { + currentMin = expr.eval(input) + } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) { + currentMin = expr.eval(input) + } + } + + override def eval(input: Row): Any = currentMin +} + +case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references = child.references + override def nullable = child.nullable + override def dataType = child.dataType + override def toString = s"MAX($child)" + + override def asPartial: SplitEvaluation = { + val partialMax = Alias(Max(child), "PartialMax")() + SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) + } + + override def newInstance() = new MaxFunction(child, this) +} + +case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var currentMax: Any = _ + + override def update(input: Row): Unit = { + if (currentMax == null) { + currentMax = expr.eval(input) + } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) { + currentMax = expr.eval(input) + } + } + + override def eval(input: Row): Any = currentMax +} + + case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def references = child.references override def nullable = false @@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil) } - override def newInstance()= new CountFunction(child, this) + override def newInstance() = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression { @@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi override def nullable = false override def dataType = IntegerType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})" - override def newInstance()= new CountDistinctFunction(expressions, this) + override def newInstance() = new CountDistinctFunction(expressions, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN partialCount :: partialSum :: Nil) } - override def newInstance()= new AverageFunction(child, this) + override def newInstance() = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ partialSum :: Nil) } - override def newInstance()= new SumFunction(child, this) + override def newInstance() = new SumFunction(child, this) } case class SumDistinct(child: Expression) @@ -153,7 +214,7 @@ case class SumDistinct(child: Expression) override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" - override def newInstance()= new SumDistinctFunction(child, this) + override def newInstance() = new SumDistinctFunction(child, this) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -168,7 +229,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance()= new FirstFunction(child, this) + override def newInstance() = new FirstFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -176,11 +237,13 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization. + private val zero = Cast(Literal(0), expr.dataType) + private var count: Long = _ - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow)) + private val sum = MutableLiteral(zero.eval(EmptyRow)) private val sumAsDouble = Cast(sum, DoubleType) - private val addFunction = Add(sum, expr) + private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) override def eval(input: Row): Any = sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble @@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null)) + private val zero = Cast(Literal(0), expr.dataType) + + private val sum = MutableLiteral(zero.eval(null)) - private val addFunction = Add(sum, expr) + private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) override def update(input: Row): Unit = { sum.update(addFunction, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dde957d715a28..e966d89c30cf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest { Seq((1,3),(2,3),(3,3))) } + test("aggregates with nulls") { + checkAnswer( + sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), + (1, 3, 2, 6, 3) :: Nil + ) + } + test("select *") { checkAnswer( sql("SELECT * FROM testData"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index b5973c0f51be8..aa71e274f7f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -84,4 +84,14 @@ object TestData { List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) nullableRepeatedData.registerAsTable("nullableRepeatedData") + + case class NullInts(a: Integer) + val nullInts = + TestSQLContext.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil + ) + nullInts.registerAsTable("nullInts") } From 44dd57fb66bb676d753ad8d9757f9f4c03364113 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 8 May 2014 10:23:05 -0700 Subject: [PATCH 12/34] SPARK-1565, update examples to be used with spark-submit script. Commit for initial feedback, basically I am curious if we should prompt user for providing args esp. when its mandatory. And can we skip if they are not ? Also few other things that did not work like `bin/spark-submit examples/target/scala-2.10/spark-examples-1.0.0-SNAPSHOT-hadoop1.0.4.jar --class org.apache.spark.examples.SparkALS --arg 100 500 10 5 2` Not all the args get passed properly, may be I have messed up something will try to sort it out hopefully. Author: Prashant Sharma Closes #552 from ScrapCodes/SPARK-1565/update-examples and squashes the following commits: 669dd23 [Prashant Sharma] Review comments 2727e70 [Prashant Sharma] SPARK-1565, update examples to be used with spark-submit script. --- .gitignore | 1 + .../scala/org/apache/spark/SparkContext.scala | 8 ++-- .../org/apache/spark/examples/JavaHdfsLR.java | 13 ++++--- .../apache/spark/examples/JavaLogQuery.java | 13 +++---- .../apache/spark/examples/JavaPageRank.java | 15 +++++--- .../apache/spark/examples/JavaSparkPi.java | 18 ++++----- .../org/apache/spark/examples/JavaTC.java | 24 ++++++------ .../apache/spark/examples/JavaWordCount.java | 12 +++--- .../apache/spark/examples/mllib/JavaALS.java | 22 +++++------ .../spark/examples/mllib/JavaKMeans.java | 22 +++++------ .../apache/spark/examples/mllib/JavaLR.java | 18 ++++----- .../spark/examples/sql/JavaSparkSQL.java | 5 ++- .../streaming/JavaFlumeEventCount.java | 19 ++++------ .../streaming/JavaKafkaWordCount.java | 27 +++++++------- .../streaming/JavaNetworkWordCount.java | 25 ++++++------- .../examples/streaming/JavaQueueStream.java | 22 +++++------ .../apache/spark/examples/BroadcastTest.scala | 22 +++++------ .../spark/examples/CassandraCQLTest.scala | 19 +++++----- .../apache/spark/examples/CassandraTest.scala | 10 ++--- .../examples/ExceptionHandlingTest.scala | 11 ++---- .../apache/spark/examples/GroupByTest.scala | 25 ++++++------- .../org/apache/spark/examples/HBaseTest.scala | 6 +-- .../org/apache/spark/examples/HdfsTest.scala | 4 +- .../org/apache/spark/examples/LogQuery.scala | 14 +++---- .../spark/examples/MultiBroadcastTest.scala | 17 ++++----- .../examples/SimpleSkewedGroupByTest.scala | 24 ++++++------ .../spark/examples/SkewedGroupByTest.scala | 25 ++++++------- .../org/apache/spark/examples/SparkALS.scala | 18 +++------ .../apache/spark/examples/SparkHdfsLR.scala | 13 ++++--- .../apache/spark/examples/SparkKMeans.scala | 18 ++++----- .../org/apache/spark/examples/SparkLR.scala | 11 ++---- .../apache/spark/examples/SparkPageRank.scala | 14 +++---- .../org/apache/spark/examples/SparkPi.scala | 10 ++--- .../org/apache/spark/examples/SparkTC.scala | 12 ++---- .../spark/examples/SparkTachyonHdfsLR.scala | 12 ++---- .../spark/examples/SparkTachyonPi.scala | 10 ++--- .../examples/bagel/WikipediaPageRank.scala | 10 ++--- .../bagel/WikipediaPageRankStandalone.scala | 10 ++--- .../examples/graphx/LiveJournalPageRank.scala | 6 +-- .../spark/examples/sql/RDDRelation.scala | 5 ++- .../examples/sql/hive/HiveFromSpark.scala | 5 ++- .../examples/streaming/ActorWordCount.scala | 21 +++++------ .../examples/streaming/FlumeEventCount.scala | 14 +++---- .../examples/streaming/HdfsWordCount.scala | 18 ++++----- .../examples/streaming/KafkaWordCount.scala | 21 +++++------ .../examples/streaming/MQTTWordCount.scala | 26 ++++++------- .../examples/streaming/NetworkWordCount.scala | 23 +++++------- .../examples/streaming/QueueStream.scala | 10 ++--- .../examples/streaming/RawNetworkGrep.scala | 16 ++++---- .../RecoverableNetworkWordCount.scala | 37 ++++++++++--------- .../streaming/StatefulNetworkWordCount.scala | 21 +++++------ .../streaming/TwitterAlgebirdCMS.scala | 15 +++----- .../streaming/TwitterAlgebirdHLL.scala | 14 +++---- .../streaming/TwitterPopularTags.scala | 13 ++----- .../examples/streaming/ZeroMQWordCount.scala | 23 ++++++------ .../apache/spark/graphx/lib/Analytics.scala | 18 +++++---- 56 files changed, 405 insertions(+), 480 deletions(-) diff --git a/.gitignore b/.gitignore index 32b603f1bc84f..ad72588b472d6 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ unit-tests.log /lib/ rat-results.txt scalastyle.txt +conf/*.conf # For Hive metastore_db/ diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index eb14d87467af7..9d7c2c8d3d630 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -74,10 +74,10 @@ class SparkContext(config: SparkConf) extends Logging { * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. */ - @DeveloperApi - def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { - this(config) - this.preferredNodeLocationData = preferredNodeLocationData + @DeveloperApi + def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { + this(config) + this.preferredNodeLocationData = preferredNodeLocationData } /** diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index bd96274021756..6c177de359b60 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -17,6 +17,7 @@ package org.apache.spark.examples; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -103,16 +104,16 @@ public static void printWeights(double[] a) { public static void main(String[] args) { - if (args.length < 3) { - System.err.println("Usage: JavaHdfsLR "); + if (args.length < 2) { + System.err.println("Usage: JavaHdfsLR "); System.exit(1); } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaHdfsLR.class)); - JavaRDD lines = sc.textFile(args[1]); + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); JavaRDD points = lines.map(new ParsePoint()).cache(); - int ITERATIONS = Integer.parseInt(args[2]); + int ITERATIONS = Integer.parseInt(args[1]); // Initialize w to a random value double[] w = new double[D]; diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 3f7a879538016..812e9d5580cbf 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -20,6 +20,7 @@ import com.google.common.collect.Lists; import scala.Tuple2; import scala.Tuple3; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -34,6 +35,8 @@ /** * Executes a roll up-style query against Apache logs. + * + * Usage: JavaLogQuery [logFile] */ public final class JavaLogQuery { @@ -97,15 +100,11 @@ public static Stats extractStats(String line) { } public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [logFile]"); - System.exit(1); - } - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaLogQuery.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); - JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); + JavaRDD dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs); JavaPairRDD, Stats> extracted = dataSet.mapToPair(new PairFunction, Stats>() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index e31f676f5fd4c..7ea6df9c17245 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -18,9 +18,12 @@ package org.apache.spark.examples; + import scala.Tuple2; import com.google.common.collect.Iterables; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -54,20 +57,20 @@ public Double call(Double a, Double b) { } public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.err.println("Usage: JavaPageRank "); + if (args.length < 2) { + System.err.println("Usage: JavaPageRank "); System.exit(1); } - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaPageRank.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); // Loads in input file. It should be in format of: // URL neighbor URL // URL neighbor URL // URL neighbor URL // ... - JavaRDD lines = ctx.textFile(args[1], 1); + JavaRDD lines = ctx.textFile(args[0], 1); // Loads all URLs from input file and initialize their neighbors. JavaPairRDD> links = lines.mapToPair(new PairFunction() { @@ -87,7 +90,7 @@ public Double call(Iterable rs) { }); // Calculates and updates URL ranks continuously using PageRank algorithm. - for (int current = 0; current < Integer.parseInt(args[2]); current++) { + for (int current = 0; current < Integer.parseInt(args[1]); current++) { // Calculates URL contributions to the rank of other URLs. JavaPairRDD contribs = links.join(ranks).values() .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index ac8df02c4630b..11157d7573fae 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -17,6 +17,7 @@ package org.apache.spark.examples; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -25,19 +26,18 @@ import java.util.ArrayList; import java.util.List; -/** Computes an approximation to pi */ +/** + * Computes an approximation to pi + * Usage: JavaSparkPi [slices] + */ public final class JavaSparkPi { + public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaSparkPi [slices]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaSparkPi", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkPi.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); - int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; + int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2; int n = 100000 * slices; List l = new ArrayList(n); for (int i = 0; i < n; i++) { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index d66b9ba265fe8..2563fcdd234bb 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -17,19 +17,22 @@ package org.apache.spark.examples; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.PairFunction; - import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.Set; +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.PairFunction; + /** * Transitive closure on a graph, implemented in Java. + * Usage: JavaTC [slices] */ public final class JavaTC { @@ -61,14 +64,9 @@ public Tuple2 call(Tuple2> t } public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaTC []"); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaTC.class)); - Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + Integer slices = (args.length > 0) ? Integer.parseInt(args[0]): 2; JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); // Linear transitive closure: each round grows paths by one edge, diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 87c1b80981961..9a6a944f7edef 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples; import scala.Tuple2; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -33,14 +34,15 @@ public final class JavaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaWordCount "); + + if (args.length < 1) { + System.err.println("Usage: JavaWordCount "); System.exit(1); } - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaWordCount.class)); - JavaRDD lines = ctx.textFile(args[1], 1); + SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); + JavaRDD lines = ctx.textFile(args[0], 1); JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java index 4533c4c5f241a..8d381d4e0a943 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java @@ -17,6 +17,7 @@ package org.apache.spark.examples.mllib; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -57,23 +58,22 @@ public String call(Tuple2 element) { public static void main(String[] args) { - if (args.length != 5 && args.length != 6) { + if (args.length < 4) { System.err.println( - "Usage: JavaALS []"); + "Usage: JavaALS []"); System.exit(1); } - - int rank = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - String outputDir = args[4]; + SparkConf sparkConf = new SparkConf().setAppName("JavaALS"); + int rank = Integer.parseInt(args[1]); + int iterations = Integer.parseInt(args[2]); + String outputDir = args[3]; int blocks = -1; - if (args.length == 6) { - blocks = Integer.parseInt(args[5]); + if (args.length == 5) { + blocks = Integer.parseInt(args[4]); } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaALS.class)); - JavaRDD lines = sc.textFile(args[1]); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); JavaRDD ratings = lines.map(new ParseRating()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java index 0cfb8e69ed28f..f796123a25727 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java @@ -19,6 +19,7 @@ import java.util.regex.Pattern; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -48,24 +49,21 @@ public Vector call(String line) { } public static void main(String[] args) { - - if (args.length < 4) { + if (args.length < 3) { System.err.println( - "Usage: JavaKMeans []"); + "Usage: JavaKMeans []"); System.exit(1); } - - String inputFile = args[1]; - int k = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); + String inputFile = args[0]; + int k = Integer.parseInt(args[1]); + int iterations = Integer.parseInt(args[2]); int runs = 1; - if (args.length >= 5) { - runs = Integer.parseInt(args[4]); + if (args.length >= 4) { + runs = Integer.parseInt(args[3]); } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(inputFile); JavaRDD points = lines.map(new ParsePoint()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java index f6e48b498727b..eceb6927d5551 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java @@ -19,6 +19,7 @@ import java.util.regex.Pattern; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -51,17 +52,16 @@ public LabeledPoint call(String line) { } public static void main(String[] args) { - if (args.length != 4) { - System.err.println("Usage: JavaLR "); + if (args.length != 3) { + System.err.println("Usage: JavaLR "); System.exit(1); } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaLR.class)); - JavaRDD lines = sc.textFile(args[1]); + SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[2]); - int iterations = Integer.parseInt(args[3]); + double stepSize = Double.parseDouble(args[1]); + int iterations = Integer.parseInt(args[2]); // Another way to configure LogisticRegression // @@ -73,7 +73,7 @@ public static void main(String[] args) { // LogisticRegressionModel model = lr.train(points.rdd()); LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); + iterations, stepSize); System.out.print("Final w: " + model.weights()); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index d62a72f53443c..ad5ec84b71e69 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -20,6 +20,7 @@ import java.io.Serializable; import java.util.List; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -51,8 +52,8 @@ public void setAge(int age) { } public static void main(String[] args) throws Exception { - JavaSparkContext ctx = new JavaSparkContext("local", "JavaSparkSQL", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkSQL.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); JavaSQLContext sqlCtx = new JavaSQLContext(ctx); // Load a text file and convert each line to a Java Bean. diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index a5ece68cef870..400b68c2215b3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.Function; import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.*; @@ -31,9 +32,8 @@ * an Avro server on at the request host:port address and listen for requests. * Your Flume AvroSink should be pointed to this address. * - * Usage: JavaFlumeEventCount + * Usage: JavaFlumeEventCount * - * is a Spark master URL * is the host the Flume receiver will be started on - a receiver * creates a server and listens for flume events. * is the port the Flume receiver will listen on. @@ -43,22 +43,19 @@ private JavaFlumeEventCount() { } public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaFlumeEventCount "); + if (args.length != 2) { + System.err.println("Usage: JavaFlumeEventCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - String master = args[0]; - String host = args[1]; - int port = Integer.parseInt(args[2]); + String host = args[0]; + int port = Integer.parseInt(args[1]); Duration batchInterval = new Duration(2000); - - JavaStreamingContext ssc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaFlumeEventCount.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, "localhost", port); flumeStream.count(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index da51eb189a649..6a74cc50d19ed 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -21,7 +21,11 @@ import java.util.HashMap; import java.util.regex.Pattern; + +import scala.Tuple2; + import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; @@ -33,19 +37,18 @@ import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.kafka.KafkaUtils; -import scala.Tuple2; /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: JavaKafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: JavaKafkaWordCount * is a list of one or more zookeeper servers that make quorum * is the name of kafka consumer group * is a list of one or more kafka topics to consume from * is the number of threads the kafka consumer should use * * Example: - * `./bin/run-example org.apache.spark.examples.streaming.JavaKafkaWordCount local[2] zoo01,zoo02, + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.JavaKafkaWordCount zoo01,zoo02, \ * zoo03 my-consumer-group topic1,topic2 1` */ @@ -56,27 +59,25 @@ private JavaKafkaWordCount() { } public static void main(String[] args) { - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount "); + if (args.length < 4) { + System.err.println("Usage: JavaKafkaWordCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - + SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); // Create the context with a 1 second batch size - JavaStreamingContext jssc = new JavaStreamingContext(args[0], "KafkaWordCount", - new Duration(2000), System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaKafkaWordCount.class)); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); - int numThreads = Integer.parseInt(args[4]); + int numThreads = Integer.parseInt(args[3]); Map topicMap = new HashMap(); - String[] topics = args[3].split(","); + String[] topics = args[2].split(","); for (String topic: topics) { topicMap.put(topic, numThreads); } JavaPairReceiverInputDStream messages = - KafkaUtils.createStream(jssc, args[1], args[2], topicMap); + KafkaUtils.createStream(jssc, args[0], args[1], topicMap); JavaDStream lines = messages.map(new Function, String>() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index ac84991d87b8b..e5cbd39f437c2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -17,9 +17,10 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import scala.Tuple2; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; @@ -27,41 +28,39 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import java.util.regex.Pattern; /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: JavaNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: JavaNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive data. * * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./run org.apache.spark.examples.streaming.JavaNetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.JavaNetworkWordCount localhost 9999` */ public final class JavaNetworkWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) { - if (args.length < 3) { - System.err.println("Usage: JavaNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1"); + if (args.length < 2) { + System.err.println("Usage: JavaNetworkWordCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - + SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount"); // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount", - new Duration(1000), System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); // Create a JavaReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') - JavaReceiverInputDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); + JavaReceiverInputDStream lines = ssc.socketTextStream(args[0], Integer.parseInt(args[1])); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java index 819311968fac5..4ce8437f82705 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java @@ -17,8 +17,16 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + import scala.Tuple2; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; @@ -28,25 +36,17 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; - public final class JavaQueueStream { private JavaQueueStream() { } public static void main(String[] args) throws Exception { - if (args.length < 1) { - System.err.println("Usage: JavaQueueStream "); - System.exit(1); - } StreamingExamples.setStreamingLogLevels(); + SparkConf sparkConf = new SparkConf().setAppName("JavaQueueStream"); // Create the context - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), - System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaQueueStream.class)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); // Create the queue through which RDDs can be pushed to // a QueueInputDStream diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index f6dfd2c4c6217..973049b95a7bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -17,28 +17,26 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} +/** + * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize] + */ object BroadcastTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: BroadcastTest [slices] [numElem] [broadcastAlgo]" + - " [blockSize]") - System.exit(1) - } - val bcName = if (args.length > 3) args(3) else "Http" - val blockSize = if (args.length > 4) args(4) else "4096" + val bcName = if (args.length > 2) args(2) else "Http" + val blockSize = if (args.length > 3) args(3) else "4096" System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory") System.setProperty("spark.broadcast.blockSize", blockSize) + val sparkConf = new SparkConf().setAppName("Broadcast Test") - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sc = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 + val slices = if (args.length > 0) args(0).toInt else 2 + val num = if (args.length > 1) args(1).toInt else 1000000 val arr1 = new Array[Int](num) for (i <- 0 until arr1.length) { diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 3798329fc2f41..9a00701f985f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -30,7 +30,7 @@ import org.apache.cassandra.hadoop.cql3.CqlOutputFormat import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /* @@ -65,19 +65,18 @@ import org.apache.spark.SparkContext._ /** * This example demonstrates how to read and write to cassandra column family created using CQL3 * using Spark. - * Parameters : - * Usage: ./bin/run-example org.apache.spark.examples.CassandraCQLTest local[2] localhost 9160 - * + * Parameters : + * Usage: ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.CassandraCQLTest localhost 9160 */ object CassandraCQLTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), - "CQLTestApp", - System.getenv("SPARK_HOME"), - SparkContext.jarOfClass(this.getClass).toSeq) - val cHost: String = args(1) - val cPort: String = args(2) + val sparkConf = new SparkConf().setAppName("CQLTestApp") + + val sc = new SparkContext(sparkConf) + val cHost: String = args(0) + val cPort: String = args(1) val KeySpace = "retail" val InputColumnFamily = "ordercf" val OutputColumnFamily = "salecount" diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ed5d2f9e46f29..91ba364a346a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -30,7 +30,7 @@ import org.apache.cassandra.thrift._ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /* @@ -38,10 +38,10 @@ import org.apache.spark.SparkContext._ * support for Hadoop. * * To run this example, run this file with the following command params - - * + * * * So if you want to run this on localhost this will be, - * local[3] localhost 9160 + * localhost 9160 * * The example makes some assumptions: * 1. You have already created a keyspace called casDemo and it has a column family named Words @@ -54,9 +54,9 @@ import org.apache.spark.SparkContext._ object CassandraTest { def main(args: Array[String]) { - + val sparkConf = new SparkConf().setAppName("casDemo") // Get a SparkContext - val sc = new SparkContext(args(0), "casDemo") + val sc = new SparkContext(sparkConf) // Build the job configuration with ConfigHelper provided by Cassandra val job = new Job() diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index f0dcef431b2e1..d42f63e87052e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -17,17 +17,12 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} object ExceptionHandlingTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: ExceptionHandlingTest ") - System.exit(1) - } - - val sc = new SparkContext(args(0), "ExceptionHandlingTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest") + val sc = new SparkContext(sparkConf) sc.parallelize(0 until sc.defaultParallelism).foreach { i => if (math.random > 0.75) { throw new Exception("Testing exception handling") diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index e67bb29a49405..efd91bb054981 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -19,24 +19,21 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object GroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println( - "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("GroupBy Test") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + + val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index adbd1c02fa2ea..a8c338480e6e2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -26,11 +26,9 @@ import org.apache.spark.rdd.NewHadoopRDD object HBaseTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HBaseTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - + val sparkConf = new SparkConf().setAppName("HBaseTest") + val sc = new SparkContext(sparkConf) val conf = HBaseConfiguration.create() - // Other options for configuring scan behavior are available. More information available at // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html conf.set(TableInputFormat.INPUT_TABLE, args(1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index c7a4884af10b7..331de3ad1ef53 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -21,8 +21,8 @@ import org.apache.spark._ object HdfsTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HdfsTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("HdfsTest") + val sc = new SparkContext(sparkConf) val file = sc.textFile(args(1)) val mapped = file.map(s => s.length).cache() for (iter <- 1 to 10) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index f77a444ff7a9f..4c655b84fde2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. + * + * Usage: LogQuery [logFile] */ object LogQuery { val exampleApacheLogs = List( @@ -40,16 +42,12 @@ object LogQuery { ) def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: LogQuery [logFile]") - System.exit(1) - } - val sc = new SparkContext(args(0), "Log Query", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("Log Query") + val sc = new SparkContext(sparkConf) val dataSet = - if (args.length == 2) sc.textFile(args(1)) else sc.parallelize(exampleApacheLogs) + if (args.length == 1) sc.textFile(args(0)) else sc.parallelize(exampleApacheLogs) // scalastyle:off val apacheLogRegex = """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index c8985eae33de3..2a5c0c0defe13 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,20 +18,19 @@ package org.apache.spark.examples import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} +/** + * Usage: MultiBroadcastTest [slices] [numElem] + */ object MultiBroadcastTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: MultiBroadcastTest [] [numElem]") - System.exit(1) - } - val sc = new SparkContext(args(0), "Multi-Broadcast Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test") + val sc = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 + val slices = if (args.length > 0) args(0).toInt else 2 + val num = if (args.length > 1) args(1).toInt else 1000000 val arr1 = new Array[Int](num) for (i <- 0 until arr1.length) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 54e8503711e30..5291ab81f459e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -19,25 +19,23 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] + */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SimpleSkewedGroupByTest " + - "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") - System.exit(1) - } - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - var ratio = if (args.length > 5) args(5).toInt else 5.0 + val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + var ratio = if (args.length > 4) args(4).toInt else 5.0 - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 1c5f22e1c00bb..017d4e1e5ce13 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -19,24 +19,21 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object SkewedGroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println( - "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("GroupBy Test") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + + val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 0dc726aecdd28..5cbc966bf06ca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -88,32 +88,24 @@ object SparkALS { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - - var host = "" var slices = 0 - val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) + val options = (0 to 4).map(i => if (i < args.length) Some(args(i)) else None) options.toArray match { - case Array(host_, m, u, f, iters, slices_) => - host = host_.get + case Array(m, u, f, iters, slices_) => M = m.getOrElse("100").toInt U = u.getOrElse("500").toInt F = f.getOrElse("10").toInt ITERATIONS = iters.getOrElse("5").toInt slices = slices_.getOrElse("2").toInt case _ => - System.err.println("Usage: SparkALS [ ]") + System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") System.exit(1) } printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) - - val sc = new SparkContext(host, "SparkALS", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("SparkALS") + val sc = new SparkContext(sparkConf) val R = generateR() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 3a6f18c33ea4b..4906a696e90a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -49,20 +49,21 @@ object SparkHdfsLR { } def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkHdfsLR ") + if (args.length < 2) { + System.err.println("Usage: SparkHdfsLR ") System.exit(1) } - val inputPath = args(1) + + val sparkConf = new SparkConf().setAppName("SparkHdfsLR") + val inputPath = args(0) val conf = SparkHadoopUtil.get.newConfiguration() - val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq, Map(), + val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) )) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).cache() - val ITERATIONS = args(2).toInt + val ITERATIONS = args(1).toInt // Initialize w to a random value var w = DenseVector.fill(D){2 * rand.nextDouble - 1} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index dcae9591b0407..4d28e0aad6597 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -21,7 +21,7 @@ import java.util.Random import breeze.linalg.{Vector, DenseVector, squaredDistance} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** @@ -52,16 +52,16 @@ object SparkKMeans { } def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: SparkLocalKMeans ") - System.exit(1) + if (args.length < 3) { + System.err.println("Usage: SparkKMeans ") + System.exit(1) } - val sc = new SparkContext(args(0), "SparkLocalKMeans", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val lines = sc.textFile(args(1)) + val sparkConf = new SparkConf().setAppName("SparkKMeans") + val sc = new SparkContext(sparkConf) + val lines = sc.textFile(args(0)) val data = lines.map(parseVector _).cache() - val K = args(2).toInt - val convergeDist = args(3).toDouble + val K = args(1).toInt + val convergeDist = args(2).toDouble val kPoints = data.takeSample(withReplacement = false, K, 42).toArray var tempDist = 1.0 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 4f74882ccbea5..99ceb3089e9fe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -27,6 +27,7 @@ import org.apache.spark._ /** * Logistic regression based classification. + * Usage: SparkLR [slices] */ object SparkLR { val N = 10000 // Number of data points @@ -47,13 +48,9 @@ object SparkLR { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkLR []") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val numSlices = if (args.length > 1) args(1).toInt else 2 + val sparkConf = new SparkConf().setAppName("SparkLR") + val sc = new SparkContext(sparkConf) + val numSlices = if (args.length > 0) args(0).toInt else 2 val points = sc.parallelize(generateData, numSlices).cache() // Initialize w to a random value diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index fa41c5c560943..40b36c779afd6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples import org.apache.spark.SparkContext._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} /** * Computes the PageRank of URLs from an input file. Input file should @@ -31,14 +31,10 @@ import org.apache.spark.SparkContext */ object SparkPageRank { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: PageRank ") - System.exit(1) - } - var iters = args(2).toInt - val ctx = new SparkContext(args(0), "PageRank", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val lines = ctx.textFile(args(1), 1) + val sparkConf = new SparkConf().setAppName("PageRank") + var iters = args(1).toInt + val ctx = new SparkContext(sparkConf) + val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => val parts = s.split("\\s+") (parts(0), parts(1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index d8f5720504223..9fbb0a800d735 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -24,13 +24,9 @@ import org.apache.spark._ /** Computes an approximation to pi */ object SparkPi { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkPi", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val slices = if (args.length > 1) args(1).toInt else 2 + val conf = new SparkConf().setAppName("Spark Pi") + val spark = new SparkContext(conf) + val slices = if (args.length > 0) args(0).toInt else 2 val n = 100000 * slices val count = spark.parallelize(1 to n, slices).map { i => val x = random * 2 - 1 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 17d983cd875db..f7f83086df3db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import scala.util.Random import scala.collection.mutable -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** @@ -42,13 +42,9 @@ object SparkTC { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTC []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTC", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val slices = if (args.length > 1) args(1).toInt else 2 + val sparkConf = new SparkConf().setAppName("SparkTC") + val spark = new SparkContext(sparkConf) + val slices = if (args.length > 0) args(0).toInt else 2 var tc = spark.parallelize(generateGraph, slices).cache() // Linear transitive closure: each round grows paths by one edge, diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 7e43c384bdb9d..22127621867e1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -51,20 +51,16 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkTachyonHdfsLR ") - System.exit(1) - } - val inputPath = args(1) + val inputPath = args(0) val conf = SparkHadoopUtil.get.newConfiguration() - val sc = new SparkContext(args(0), "SparkTachyonHdfsLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq, Map(), + val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") + val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) )) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) - val ITERATIONS = args(2).toInt + val ITERATIONS = args(1).toInt // Initialize w to a random value var w = DenseVector.fill(D){2 * rand.nextDouble - 1} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 93459110e4e0e..7743f7968b100 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -28,14 +28,10 @@ import org.apache.spark.storage.StorageLevel */ object SparkTachyonPi { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTachyonPi []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTachyonPi", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("SparkTachyonPi") + val spark = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 + val slices = if (args.length > 0) args(0).toInt else 2 val n = 100000 * slices val rdd = spark.parallelize(1 to n, slices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index 25bd55ca88b94..235c3bf820244 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -32,22 +32,22 @@ import scala.xml.{XML,NodeSeq} */ object WikipediaPageRank { def main(args: Array[String]) { - if (args.length < 5) { + if (args.length < 4) { System.err.println( - "Usage: WikipediaPageRank ") + "Usage: WikipediaPageRank ") System.exit(-1) } val sparkConf = new SparkConf() + sparkConf.setAppName("WikipediaPageRank") sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) val inputFile = args(0) val threshold = args(1).toDouble val numPartitions = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean + val usePartitioner = args(3).toBoolean - sparkConf.setMaster(host).setAppName("WikipediaPageRank") + sparkConf.setAppName("WikipediaPageRank") val sc = new SparkContext(sparkConf) // Parse the Wikipedia page data into a graph diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index dee3cb6c0abae..a197dac87d6db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -30,22 +30,20 @@ import org.apache.spark.rdd.RDD object WikipediaPageRankStandalone { def main(args: Array[String]) { - if (args.length < 5) { + if (args.length < 4) { System.err.println("Usage: WikipediaPageRankStandalone " + - " ") + " ") System.exit(-1) } val sparkConf = new SparkConf() sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - val inputFile = args(0) val threshold = args(1).toDouble val numIterations = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean + val usePartitioner = args(3).toBoolean - sparkConf.setMaster(host).setAppName("WikipediaPageRankStandalone") + sparkConf.setAppName("WikipediaPageRankStandalone") val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index d58fddff2b5ec..6ef3b62dcbedc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -28,9 +28,9 @@ import org.apache.spark.graphx.lib.Analytics */ object LiveJournalPageRank { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 1) { System.err.println( - "Usage: LiveJournalPageRank \n" + + "Usage: LiveJournalPageRank \n" + " [--tol=]\n" + " The tolerance allowed at convergence (smaller => more accurate). Default is " + "0.001.\n" + @@ -44,6 +44,6 @@ object LiveJournalPageRank { System.exit(-1) } - Analytics.main(args.patch(1, List("pagerank"), 0)) + Analytics.main(args.patch(0, List("pagerank"), 0)) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index ff9254b044c24..61c460c6b1de8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext // One method for defining the schema of an RDD is to make a case class with the desired column @@ -26,7 +26,8 @@ case class Record(key: Int, value: String) object RDDRelation { def main(args: Array[String]) { - val sc = new SparkContext("local", "RDDRelation") + val sparkConf = new SparkConf().setAppName("RDDRelation") + val sc = new SparkContext(sparkConf) val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 66ce93a26ef42..b262fabbe0e0d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.sql.hive -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ import org.apache.spark.sql.hive.LocalHiveContext @@ -25,7 +25,8 @@ object HiveFromSpark { case class Record(key: Int, value: String) def main(args: Array[String]) { - val sc = new SparkContext("local", "HiveFromSpark") + val sparkConf = new SparkConf().setAppName("HiveFromSpark") + val sc = new SparkContext(sparkConf) // A local hive context creates an instance of the Hive Metastore in process, storing the // the warehouse data in the current directory. This location can be overridden by diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 84cf43df0f96c..e29e16a9c1b17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -126,31 +126,30 @@ object FeederActor { /** * A sample word count program demonstrating the use of plugging in * Actor as Receiver - * Usage: ActorWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: ActorWordCount * and describe the AkkaSystem that Spark Sample feeder is running on. * * To run this example locally, you may run Feeder Actor as - * `$ ./bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` * and then run the example - * `./bin/run-example org.apache.spark.examples.streaming.ActorWordCount local[2] 127.0.1.1 9999` + * `./bin/spark-submit examples.jar --class org.apache.spark.examples.streaming.ActorWordCount \ + * 127.0.1.1 9999` */ object ActorWordCount { def main(args: Array[String]) { - if (args.length < 3) { + if (args.length < 2) { System.err.println( - "Usage: ActorWordCount " + - "In local mode, should be 'local[n]' with n > 1") + "Usage: ActorWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Seq(master, host, port) = args.toSeq - + val Seq(host, port) = args.toSeq + val sparkConf = new SparkConf().setAppName("ActorWordCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) /* * Following is the use of actorStream to plug in custom actor as receiver diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 5b2a1035fc779..38362edac27f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ @@ -29,9 +30,8 @@ import org.apache.spark.util.IntParam * an Avro server on at the request host:port address and listen for requests. * Your Flume AvroSink should be pointed to this address. * - * Usage: FlumeEventCount + * Usage: FlumeEventCount * - * is a Spark master URL * is the host the Flume receiver will be started on - a receiver * creates a server and listens for flume events. * is the port the Flume receiver will listen on. @@ -40,21 +40,21 @@ object FlumeEventCount { def main(args: Array[String]) { if (args.length != 3) { System.err.println( - "Usage: FlumeEventCount ") + "Usage: FlumeEventCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, host, IntParam(port)) = args + val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) + val sparkConf = new SparkConf().setAppName("FlumeEventCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, batchInterval) // Create a flume stream - val stream = FlumeUtils.createStream(ssc, host,port,StorageLevel.MEMORY_ONLY_SER_2) + val stream = FlumeUtils.createStream(ssc, host, port, StorageLevel.MEMORY_ONLY_SER_2) // Print out the count of events received from this server in each batch stream.count().map(cnt => "Received " + cnt + " flume events." ).print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index b440956ba3137..55ac48cfb6d10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -17,35 +17,35 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ /** * Counts words in new text files created in the given directory - * Usage: HdfsWordCount - * is the Spark master URL. + * Usage: HdfsWordCount * is the directory that Spark Streaming will use to find and read new text files. * * To run this on your local machine on directory `localdir`, run this example - * `$ ./bin/run-example org.apache.spark.examples.streaming.HdfsWordCount local[2] localdir` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.HdfsWordCount localdir` * Then create a text file in `localdir` and the words in the file will get counted. */ object HdfsWordCount { def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: HdfsWordCount ") + if (args.length < 1) { + System.err.println("Usage: HdfsWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("HdfsWordCount") // Create the context - val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) + val lines = ssc.textFileStream(args(0)) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index c3aae5af05b1c..3af806981f37a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -24,34 +24,33 @@ import kafka.producer._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf -// scalastyle:off /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: KafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: KafkaWordCount * is a list of one or more zookeeper servers that make quorum * is the name of kafka consumer group * is a list of one or more kafka topics to consume from * is the number of threads the kafka consumer should use * * Example: - * `./bin/run-example org.apache.spark.examples.streaming.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.KafkaWordCount local[2] zoo01,zoo02,zoo03 \ + * my-consumer-group topic1,topic2 1` */ -// scalastyle:on object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount ") + if (args.length < 4) { + System.err.println("Usage: KafkaWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, zkQuorum, group, topics, numThreads) = args - - val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val Array(zkQuorum, group, topics, numThreads) = args + val sparkConf = new SparkConf().setAppName("KafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 47bf1e5a06439..3a10daa9ab84a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -24,6 +24,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.mqtt._ +import org.apache.spark.SparkConf /** * A simple Mqtt publisher for demonstration purposes, repeatedly publishes @@ -64,7 +65,6 @@ object MQTTPublisher { } } -// scalastyle:off /** * A sample wordcount with MqttStream stream * @@ -74,30 +74,28 @@ object MQTTPublisher { * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/ * Example Java code for Mqtt Publisher and Subscriber can be found here * https://bitbucket.org/mkjinesh/mqttclient - * Usage: MQTTWordCount - * In local mode, should be 'local[n]' with n > 1 - * and describe where Mqtt publisher is running. + * Usage: MQTTWordCount +\ * and describe where Mqtt publisher is running. * * To run this example locally, you may run publisher as - * `$ ./bin/run-example org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` * and run the example as - * `$ ./bin/run-example org.apache.spark.examples.streaming.MQTTWordCount local[2] tcp://localhost:1883 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.MQTTWordCount tcp://localhost:1883 foo` */ -// scalastyle:on object MQTTWordCount { def main(args: Array[String]) { - if (args.length < 3) { + if (args.length < 2) { System.err.println( - "Usage: MQTTWordCount " + - " In local mode, should be 'local[n]' with n > 1") + "Usage: MQTTWordCount ") System.exit(1) } - val Seq(master, brokerUrl, topic) = args.toSeq - - val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"), - StreamingContext.jarOfClass(this.getClass).toSeq) + val Seq(brokerUrl, topic) = args.toSeq + val sparkConf = new SparkConf().setAppName("MQTTWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.toString.split(" ")) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index acfe9a4da3596..ad7a199b2c0ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -17,41 +17,38 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.storage.StorageLevel -// scalastyle:off /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. + * Usage: NetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. * * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.NetworkWordCount localhost 9999` */ -// scalastyle:on object NetworkWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("NetworkWordCount"); // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt, StorageLevel.MEMORY_ONLY_SER) + val lines = ssc.socketTextStream(args(0), args(1).toInt, StorageLevel.MEMORY_ONLY_SER) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index f92f72f2de876..4caa90659111a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.streaming import scala.collection.mutable.SynchronizedQueue +import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ @@ -26,16 +27,11 @@ import org.apache.spark.streaming.StreamingContext._ object QueueStream { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: QueueStream ") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("QueueStream") // Create the context - val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create the queue through which RDDs can be pushed to // a QueueInputDStream diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 1b0319a046433..a9aaa445bccb6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.util.IntParam @@ -27,29 +28,26 @@ import org.apache.spark.util.IntParam * will only work with spark.streaming.util.RawTextSender running on all worker nodes * and with Spark using Kryo serialization (set Java property "spark.serializer" to * "org.apache.spark.serializer.KryoSerializer"). - * Usage: RawNetworkGrep - * is the Spark master URL + * Usage: RawNetworkGrep * is the number rawNetworkStreams, which should be same as number * of work nodes in the cluster * is "localhost". * is the port on which RawTextSender is running in the worker nodes. * is the Spark Streaming batch duration in milliseconds. */ - object RawNetworkGrep { def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: RawNetworkGrep ") + if (args.length != 4) { + System.err.println("Usage: RawNetworkGrep ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - + val Array(IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + val sparkConf = new SparkConf().setAppName("RawNetworkGrep") // Create the context - val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Duration(batchMillis)) val rawStreams = (1 to numStreams).map(_ => ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index b0bc31cc66ab5..ace785d9fe4c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -17,19 +17,21 @@ package org.apache.spark.examples.streaming +import java.io.File +import java.nio.charset.Charset + +import com.google.common.io.Files + +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.util.IntParam -import java.io.File -import org.apache.spark.rdd.RDD -import com.google.common.io.Files -import java.nio.charset.Charset /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: NetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. directory to HDFS-compatible file system which checkpoint data * file to which the word counts will be appended @@ -44,8 +46,9 @@ import java.nio.charset.Charset * * and run the example as * - * `$ ./run-example org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ - * local[2] localhost 9999 ~/checkpoint/ ~/out` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ + * localhost 9999 ~/checkpoint/ ~/out` * * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if @@ -67,17 +70,16 @@ import java.nio.charset.Charset object RecoverableNetworkWordCount { - def createContext(master: String, ip: String, port: Int, outputPath: String) = { + def createContext(ip: String, port: Int, outputPath: String) = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint println("Creating new context") val outputFile = new File(outputPath) if (outputFile.exists()) outputFile.delete() - + val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount") // Create the context with a 1 second batch size - val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') @@ -94,13 +96,12 @@ object RecoverableNetworkWordCount { } def main(args: Array[String]) { - if (args.length != 5) { + if (args.length != 4) { System.err.println("You arguments were " + args.mkString("[", ", ", "]")) System.err.println( """ - |Usage: RecoverableNetworkWordCount - | is the Spark master URL. In local mode, should be - | 'local[n]' with n > 1. and describe the TCP server that Spark + |Usage: RecoverableNetworkWordCount + | . and describe the TCP server that Spark | Streaming would connect to receive data. directory to | HDFS-compatible file system which checkpoint data file to which the | word counts will be appended @@ -111,10 +112,10 @@ object RecoverableNetworkWordCount { ) System.exit(1) } - val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args + val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args val ssc = StreamingContext.getOrCreate(checkpointDirectory, () => { - createContext(master, ip, port, outputPath) + createContext(ip, port, outputPath) }) ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 8001d56c98d86..5e1415f3cc536 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -17,28 +17,27 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// scalastyle:off + /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every * second. - * Usage: StatefulNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: StatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. * * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./bin/run-example org.apache.spark.examples.streaming.StatefulNetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar + * --class org.apache.spark.examples.streaming.StatefulNetworkWordCount localhost 9999` */ -// scalastyle:on object StatefulNetworkWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: StatefulNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: StatefulNetworkWordCount ") System.exit(1) } @@ -52,14 +51,14 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } + val sparkConf = new SparkConf().setAppName("NetworkWordCumulativeCountUpdateStateByKey") // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", - Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) + val lines = ssc.socketTextStream(args(0), args(1).toInt) val words = lines.flatMap(_.split(" ")) val wordDstream = words.map(x => (x, 1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index b12617d881787..683752ac96241 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -19,11 +19,13 @@ package org.apache.spark.examples.streaming import com.twitter.algebird._ +import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.twitter._ + // scalastyle:off /** * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute @@ -49,12 +51,6 @@ import org.apache.spark.streaming.twitter._ // scalastyle:on object TwitterAlgebirdCMS { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdCMS " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - StreamingExamples.setStreamingLogLevels() // CMS parameters @@ -65,10 +61,9 @@ object TwitterAlgebirdCMS { // K highest frequency elements to take val TOPK = 10 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterAlgebirdCMS") + val ssc = new StreamingContext(sparkConf, Seconds(10)) val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER_2) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 22f232c72545c..62db5e663b8af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -23,6 +23,8 @@ import com.twitter.algebird.HyperLogLog._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.twitter._ +import org.apache.spark.SparkConf + // scalastyle:off /** * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute @@ -42,20 +44,14 @@ import org.apache.spark.streaming.twitter._ // scalastyle:on object TwitterAlgebirdHLL { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdHLL " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ val BIT_SIZE = 12 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterAlgebirdHLL") + val ssc = new StreamingContext(sparkConf, Seconds(5)) val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index 5b58e94600a16..1ddff22cb8a42 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -21,6 +21,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import StreamingContext._ import org.apache.spark.SparkContext._ import org.apache.spark.streaming.twitter._ +import org.apache.spark.SparkConf /** * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter @@ -30,18 +31,12 @@ import org.apache.spark.streaming.twitter._ */ object TwitterPopularTags { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterPopularTags " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterPopularTags") + val ssc = new StreamingContext(sparkConf, Seconds(2)) val stream = TwitterUtils.createStream(ssc, None, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index de46e5f5b10b6..7ade3f1018ee8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -28,6 +28,7 @@ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.zeromq._ import scala.language.implicitConversions +import org.apache.spark.SparkConf /** * A simple publisher for demonstration purposes, repeatedly publishes random Messages @@ -63,30 +64,28 @@ object SimpleZeroMQPublisher { * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide] * (http://www.zeromq.org/intro:get-the-software) * - * Usage: ZeroMQWordCount - * In local mode, should be 'local[n]' with n > 1 + * Usage: ZeroMQWordCount * and describe where zeroMq publisher is running. * * To run this example locally, you may run publisher as - * `$ ./bin/run-example org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` * and run the example as - * `$ ./bin/run-example org.apache.spark.examples.streaming.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.1.1:1234 foo` */ // scalastyle:on object ZeroMQWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ZeroMQWordCount " + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: ZeroMQWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Seq(master, url, topic) = args.toSeq - + val Seq(url, topic) = args.toSeq + val sparkConf = new SparkConf().setAppName("ZeroMQWordCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala index fa533a512d53b..d901d4fe225fe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala @@ -27,10 +27,14 @@ import org.apache.spark.graphx.PartitionStrategy._ object Analytics extends Logging { def main(args: Array[String]): Unit = { - val host = args(0) - val taskType = args(1) - val fname = args(2) - val options = args.drop(3).map { arg => + if (args.length < 2) { + System.err.println("Usage: Analytics [other options]") + System.exit(1) + } + + val taskType = args(0) + val fname = args(1) + val options = args.drop(2).map { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) case _ => throw new IllegalArgumentException("Invalid argument: " + arg) @@ -71,7 +75,7 @@ object Analytics extends Logging { println("| PageRank |") println("======================================") - val sc = new SparkContext(host, "PageRank(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, minEdgePartitions = numEPart).cache() @@ -115,7 +119,7 @@ object Analytics extends Logging { println("| Connected Components |") println("======================================") - val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, minEdgePartitions = numEPart).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -137,7 +141,7 @@ object Analytics extends Logging { println("======================================") println("| Triangle Count |") println("======================================") - val sc = new SparkContext(host, "TriangleCount(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, minEdgePartitions = numEPart).partitionBy(partitionStrategy).cache() val triangles = TriangleCount.run(graph) From c3f8b78c211df6c5adae74f37e39fb55baeff723 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 8 May 2014 12:13:07 -0700 Subject: [PATCH 13/34] [SPARK-1745] Move interrupted flag from TaskContext constructor (minor) It makes little sense to start a TaskContext that is interrupted. Indeed, I searched for all use cases of it and didn't find a single instance in which `interrupted` is true on construction. This was inspired by reviewing #640, which adds an additional `@volatile var completed` that is similar. These are not the most urgent changes, but I wanted to push them out before I forget. Author: Andrew Or Closes #675 from andrewor14/task-context and squashes the following commits: 9575e02 [Andrew Or] Add space 69455d1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into task-context c471490 [Andrew Or] Oops, removed one flag too many. Adding it back. 85311f8 [Andrew Or] Move interrupted flag from TaskContext constructor --- .../scala/org/apache/spark/TaskContext.scala | 20 ++++++++++--------- .../spark/scheduler/ShuffleMapTask.scala | 3 +-- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 10 +++------- .../org/apache/spark/PipedRDDSuite.scala | 4 +--- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index fc4812753d005..51f40c339d13c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -28,13 +28,12 @@ import org.apache.spark.executor.TaskMetrics */ @DeveloperApi class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - @volatile var interrupted: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty -) extends Serializable { + val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends Serializable { @deprecated("use partitionId", "0.8.1") def splitId = partitionId @@ -42,7 +41,10 @@ class TaskContext( // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] - // Set to true when the task is completed, before the onCompleteCallbacks are executed. + // Whether the corresponding task has been killed. + @volatile var interrupted: Boolean = false + + // Whether the task has completed, before the onCompleteCallbacks are executed. @volatile var completed: Boolean = false /** @@ -58,6 +60,6 @@ class TaskContext( def executeOnCompleteCallbacks() { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach{_()} + onCompleteCallbacks.reverse.foreach { _() } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2259df0b56bad..4b0324f2b5447 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -23,7 +23,6 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap -import scala.util.Try import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics @@ -70,7 +69,7 @@ private[spark] object ShuffleMapTask { } // Since both the JarSet and FileSet have the same format this is used for both. - def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { + def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c3e03cea917b3..1912015827927 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -597,7 +597,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index fd5b0906e6765..4f178db40f638 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.mock.EasyMockSugar import org.apache.spark.rdd.RDD -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects @@ -59,8 +58,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -72,8 +70,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,8 +83,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala index 0bb6a6b09c5b5..db56a4acdd6f5 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -178,14 +178,12 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val tContext = new TaskContext(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") } else { // printenv isn't available so just pass the test - assert(true) } } From 5c5e7d5809d337ce41a7a90eb9201e12803aba48 Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Thu, 8 May 2014 13:07:30 -0700 Subject: [PATCH 14/34] Fixing typo in als.py XtY should be Xty. Author: Evan Sparks Closes #696 from etrain/patch-2 and squashes the following commits: 634cb8d [Evan Sparks] Fixing typo in als.py --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 33700ab4f8c53..01552dc1d449e 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -38,7 +38,7 @@ def update(i, vec, mat, ratings): ff = mat.shape[1] XtX = mat.T * mat - XtY = mat.T * ratings[i, :].T + Xty = mat.T * ratings[i, :].T for j in range(ff): XtX[j,j] += LAMBDA * uu From 322b1808d21143dc323493203929488d69e8878a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 May 2014 15:31:47 -0700 Subject: [PATCH 15/34] [SPARK-1754] [SQL] Add missing arithmetic DSL operations. Add missing arithmetic DSL operations: `unary_-`, `%`. Author: Takuya UESHIN Closes #689 from ueshin/issues/SPARK-1754 and squashes the following commits: a09ef69 [Takuya UESHIN] Add also missing ! (not) operation. f73ae2c [Takuya UESHIN] Remove redundant tests. 5b3f087 [Takuya UESHIN] Add tests relating DSL operations. e09c5b8 [Takuya UESHIN] Add missing arithmetic DSL operations. --- .../apache/spark/sql/catalyst/dsl/package.scala | 4 ++++ .../expressions/ExpressionEvaluationSuite.scala | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index dc83485df195c..78d3a1d8096af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -57,10 +57,14 @@ package object dsl { trait ImplicitOperators { def expr: Expression + def unary_- = UnaryMinus(expr) + def unary_! = Not(expr) + def + (other: Expression) = Add(expr, other) def - (other: Expression) = Subtract(expr, other) def * (other: Expression) = Multiply(expr, other) def / (other: Expression) = Divide(expr, other) + def % (other: Expression) = Remainder(expr, other) def && (other: Expression) = And(expr, other) def || (other: Expression) = Or(expr, other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 91605d0a260e5..344d8a304fc11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -61,7 +61,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = Not(Literal(v, BooleanType)) + val expr = ! Literal(v, BooleanType) val result = expr.eval(null) if (result != answer) fail(s"$expr should not evaluate to $result, expected: $answer") } @@ -381,6 +381,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(-c1, -1, row) + checkEvaluation(c1 + c2, 3, row) + checkEvaluation(c1 - c2, -1, row) + checkEvaluation(c1 * c2, 2, row) + checkEvaluation(c1 / c2, 0, row) + checkEvaluation(c1 % c2, 1, row) } test("BinaryComparison") { @@ -395,6 +402,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(c1 < c2, true, row) + checkEvaluation(c1 <= c2, true, row) + checkEvaluation(c1 > c2, false, row) + checkEvaluation(c1 >= c2, false, row) + checkEvaluation(c1 === c2, false, row) + checkEvaluation(c1 !== c2, true, row) } } From d38febee46ed156b0c8ec64757db6c290e488421 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 8 May 2014 17:52:32 -0700 Subject: [PATCH 16/34] MLlib documentation fix Fixed the documentation for that `loadLibSVMData` is changed to `loadLibSVMFile`. Author: DB Tsai Closes #703 from dbtsai/dbtsai-docfix and squashes the following commits: 71dd508 [DB Tsai] loadLibSVMData is changed to loadLibSVMFile --- docs/mllib-basics.md | 8 ++++---- docs/mllib-linear-methods.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index 704308802d65b..aa9321a547097 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -184,7 +184,7 @@ After loading, the feature indices are converted to zero-based.
-[`MLUtils.loadLibSVMData`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training +[`MLUtils.loadLibSVMFile`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. {% highlight scala %} @@ -192,12 +192,12 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -val training: RDD[LabeledPoint] = MLUtils.loadLibSVMData(sc, "mllib/data/sample_libsvm_data.txt") +val training: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") {% endhighlight %}
-[`MLUtils.loadLibSVMData`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training +[`MLUtils.loadLibSVMFile`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. {% highlight java %} @@ -205,7 +205,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDDimport; -RDD training = MLUtils.loadLibSVMData(jsc, "mllib/data/sample_libsvm_data.txt"); +RDD training = MLUtils.loadLibSVMFile(jsc, "mllib/data/sample_libsvm_data.txt"); {% endhighlight %}
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 40b7a7f80708c..eff617d8641e2 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -186,7 +186,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMData(sc, "mllib/data/sample_libsvm_data.txt") +val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") // Split data into training (60%) and test (40%). val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) From 910a13b3c52a6309068b4997da6df6b7d6058a1b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 8 May 2014 17:53:22 -0700 Subject: [PATCH 17/34] [SPARK-1157][MLlib] Bug fix: lossHistory should exclude rejection steps, and remove miniBatch Getting the lossHistory from Breeze's API which already excludes the rejection steps in line search. Also, remove the miniBatch in LBFGS since those quasi-Newton methods approximate the inverse of Hessian. It doesn't make sense if the gradients are computed from a varying objective. Author: DB Tsai Closes #582 from dbtsai/dbtsai-lbfgs-bug and squashes the following commits: 9cc6cf9 [DB Tsai] Removed the miniBatch in LBFGS. 1ba6a33 [DB Tsai] Formatting the code. d72c679 [DB Tsai] Using Breeze's states to get the loss. --- .../spark/mllib/optimization/LBFGS.scala | 63 ++++++++----------- .../spark/mllib/optimization/LBFGSSuite.scala | 15 ++--- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 969a0c5f7c953..8f187c9df5102 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) private var convergenceTol = 1E-4 private var maxNumIterations = 100 private var regParam = 0.0 - private var miniBatchFraction = 1.0 /** * Set the number of corrections used in the LBFGS update. Default 10. @@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } - /** - * Set fraction of data to be used for each L-BFGS iteration. Default 1.0. - */ - def setMiniBatchFraction(fraction: Double): this.type = { - this.miniBatchFraction = fraction - this - } - /** * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. @@ -110,7 +101,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) } override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { - val (weights, _) = LBFGS.runMiniBatchLBFGS( + val (weights, _) = LBFGS.runLBFGS( data, gradient, updater, @@ -118,7 +109,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) convergenceTol, maxNumIterations, regParam, - miniBatchFraction, initialWeights) weights } @@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) @DeveloperApi object LBFGS extends Logging { /** - * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches. - * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data - * in order to compute a gradient estimate. - * Sampling, and averaging the subgradients over this subset is performed using one standard + * Run Limited-memory BFGS (L-BFGS) in parallel. + * Averaging the subgradients over different partitions is performed using one standard * spark map-reduce in each iteration. * * @param data - Input data for L-BFGS. RDD of the set of data examples, each of @@ -147,14 +135,12 @@ object LBFGS extends Logging { * @param convergenceTol - The convergence tolerance of iterations for L-BFGS * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. * @param regParam - Regularization parameter - * @param miniBatchFraction - Fraction of the input data set that should be used for - * one iteration of L-BFGS. Default value 1.0. * * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the loss * computed for every iteration. */ - def runMiniBatchLBFGS( + def runLBFGS( data: RDD[(Double, Vector)], gradient: Gradient, updater: Updater, @@ -162,23 +148,33 @@ object LBFGS extends Logging { convergenceTol: Double, maxNumIterations: Int, regParam: Double, - miniBatchFraction: Double, initialWeights: Vector): (Vector, Array[Double]) = { val lossHistory = new ArrayBuffer[Double](maxNumIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction val costFun = - new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize) + new CostFun(data, gradient, updater, regParam, numExamples) val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol) - val weights = Vectors.fromBreeze( - lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)) + val states = + lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + /** + * NOTE: lossSum and loss is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + var state = states.next() + while(states.hasNext) { + lossHistory.append(state.value) + state = states.next() + } + lossHistory.append(state.value) + val weights = Vectors.fromBreeze(state.x) - logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format( + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( lossHistory.takeRight(10).mkString(", "))) (weights, lossHistory.toArray) @@ -193,9 +189,7 @@ object LBFGS extends Logging { gradient: Gradient, updater: Updater, regParam: Double, - miniBatchFraction: Double, - lossHistory: ArrayBuffer[Double], - miniBatchSize: Double) extends DiffFunction[BDV[Double]] { + numExamples: Long) extends DiffFunction[BDV[Double]] { private var i = 0 @@ -204,8 +198,7 @@ object LBFGS extends Logging { val localData = data val localGradient = gradient - val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](weights.size), 0.0))( + val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad)) @@ -223,7 +216,7 @@ object LBFGS extends Logging { Vectors.fromBreeze(weights), Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 - val loss = lossSum / miniBatchSize + regVal + val loss = lossSum / numExamples + regVal /** * It will return the gradient part of regularization using updater. * @@ -245,14 +238,8 @@ object LBFGS extends Logging { Vectors.fromBreeze(weights), Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze - // gradientTotal = gradientSum / miniBatchSize + gradientTotal - axpy(1.0 / miniBatchSize, gradientSum, gradientTotal) - - /** - * NOTE: lossSum and loss is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - lossHistory.append(loss) + // gradientTotal = gradientSum / numExamples + gradientTotal + axpy(1.0 / numExamples, gradientSum, gradientTotal) i += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index f33770aed30bd..6af1b502eb4dd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -59,7 +59,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val convergenceTol = 1e-12 val maxNumIterations = 10 - val (_, loss) = LBFGS.runMiniBatchLBFGS( + val (_, loss) = LBFGS.runLBFGS( dataRDD, gradient, simpleUpdater, @@ -67,7 +67,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing @@ -104,7 +103,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val convergenceTol = 1e-12 val maxNumIterations = 10 - val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS( + val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -112,7 +111,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) val numGDIterations = 50 @@ -150,7 +148,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val maxNumIterations = 8 var convergenceTol = 0.0 - val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS1) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -158,7 +156,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Note that the first loss is computed with initial weights, @@ -166,7 +163,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { assert(lossLBFGS1.length == 9) convergenceTol = 0.1 - val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS2) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -174,7 +171,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Based on observation, lossLBFGS2 runs 3 iterations, no theoretically guaranteed. @@ -182,7 +178,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol) convergenceTol = 0.01 - val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS3) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -190,7 +186,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // With smaller convergenceTol, it takes more steps. From 191279ce4edb940821d11a6b25cd33c8ad0af054 Mon Sep 17 00:00:00 2001 From: Funes Date: Thu, 8 May 2014 17:54:10 -0700 Subject: [PATCH 18/34] Bug fix of sparse vector conversion Fixed a small bug caused by the inconsistency of index/data array size and vector length. Author: Funes Author: funes Closes #661 from funes/bugfix and squashes the following commits: edb2b9d [funes] remove unused import 75dced3 [Funes] update test case d129a66 [Funes] Add test for sparse breeze by vector builder 64e7198 [Funes] Copy data only when necessary b85806c [Funes] Bug fix of sparse vector conversion --- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 6 +++++- .../spark/mllib/linalg/BreezeVectorConversionSuite.scala | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7cdf6bd56acd9..84d223908c1f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -136,7 +136,11 @@ object Vectors { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one } case v: BSV[Double] => - new SparseVector(v.length, v.index, v.data) + if (v.index.length == v.used) { + new SparseVector(v.length, v.index, v.data) + } else { + new SparseVector(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) + } case v: BV[_] => sys.error("Unsupported Breeze vector type: " + v.getClass.getName) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index aacaa300849aa..8abdac72902c6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -55,4 +55,13 @@ class BreezeVectorConversionSuite extends FunSuite { assert(vec.indices.eq(indices), "should not copy data") assert(vec.values.eq(values), "should not copy data") } + + test("sparse breeze with partially-used arrays to vector") { + val activeSize = 3 + val breeze = new BSV[Double](indices, values, activeSize, n) + val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] + assert(vec.size === n) + assert(vec.indices === indices.slice(0, activeSize)) + assert(vec.values === values.slice(0, activeSize)) + } } From 2fd2752e572921a9010614eb1c1238c493d34a7c Mon Sep 17 00:00:00 2001 From: Bouke van der Bijl Date: Thu, 8 May 2014 20:43:37 -0700 Subject: [PATCH 19/34] Include the sbin/spark-config.sh in spark-executor This is needed because broadcast values are broken on pyspark on Mesos, it tries to import pyspark but can't, as the PYTHONPATH is not set due to changes in ff5be9a4 https://issues.apache.org/jira/browse/SPARK-1725 Author: Bouke van der Bijl Closes #651 from bouk/include-spark-config-in-mesos-executor and squashes the following commits: b2f1295 [Bouke van der Bijl] Inline PYTHONPATH in spark-executor eedbbcc [Bouke van der Bijl] Include the sbin/spark-config.sh in spark-executor --- sbin/spark-executor | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sbin/spark-executor b/sbin/spark-executor index de5bfab563125..336549f29c9ce 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -19,5 +19,8 @@ FWDIR="$(cd `dirname $0`/..; pwd)" +export PYTHONPATH=$FWDIR/python:$PYTHONPATH +export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH + echo "Running spark-executor with framework dir = $FWDIR" exec $FWDIR/bin/spark-class org.apache.spark.executor.MesosExecutorBackend From 8b7841299439b7dc590b2f7e2339f24e8f3e19f6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 8 May 2014 20:45:29 -0700 Subject: [PATCH 20/34] [SPARK-1755] Respect SparkSubmit --name on YARN Right now, SparkSubmit ignores the `--name` flag for both yarn-client and yarn-cluster. This is a bug. In client mode, SparkSubmit treats `--name` as a [cluster config](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L170) and does not propagate this to SparkContext. In cluster mode, SparkSubmit passes this flag to `org.apache.spark.deploy.yarn.Client`, which only uses it for the [YARN ResourceManager](https://github.com/apache/spark/blob/master/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala#L80), but does not propagate this to SparkContext. This PR ensures that `spark.app.name` is always set if SparkSubmit receives the `--name` flag, which is what the usage promises. This makes it possible for applications to start a SparkContext with an empty conf `val sc = new SparkContext(new SparkConf)`, and inherit the app name from SparkSubmit. Tested both modes on a YARN cluster. Author: Andrew Or Closes #699 from andrewor14/yarn-app-name and squashes the following commits: 98f6a79 [Andrew Or] Fix tests dea932f [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-app-name c86d9ca [Andrew Or] Respect SparkSubmit --name on YARN --- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 9 +++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 10 ++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e39723f38347c..16de6f7cdb100 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -160,6 +160,7 @@ object SparkSubmit { // each deploy mode; we iterate through these below val options = List[OptionAssigner]( OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"), OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true, @@ -167,7 +168,7 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraLibraryPath"), OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"), - OptionAssigner(args.name, YARN, true, clOption = "--name"), + OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), OptionAssigner(args.queue, YARN, true, clOption = "--queue"), OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"), @@ -188,8 +189,7 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, true, clOption = "--addJars"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars"), - OptionAssigner(args.name, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.app.name") + OptionAssigner(args.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars") ) // For client mode make any added jars immediately visible on the classpath @@ -205,7 +205,8 @@ object SparkSubmit { (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) - } else if (opt.sysProp != null) { + } + if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d7e3b22ed476e..c9edb03cdeb0f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -104,7 +104,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) @@ -122,7 +122,8 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { childArgsStr should include ("--num-executors 6") mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) - sysProps should have size (1) + sysProps("spark.app.name") should be ("beauty") + sysProps("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -130,8 +131,8 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "thejar.jar", - "arg1", "arg2") + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", + "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") @@ -140,6 +141,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { classpath should contain ("one.jar") classpath should contain ("two.jar") classpath should contain ("three.jar") + sysProps("spark.app.name") should be ("trill") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.executor.cores") should be ("5") sysProps("spark.yarn.queue") should be ("thequeue") From 3f779d872d8459b262b3db9e4d12b011910b6ce9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 8 May 2014 20:46:11 -0700 Subject: [PATCH 21/34] [SPARK-1631] Correctly set the Yarn app name when launching the AM. Author: Marcelo Vanzin Closes #539 from vanzin/yarn-app-name and squashes the following commits: 7d1ca4f [Marcelo Vanzin] [SPARK-1631] Correctly set the Yarn app name when launching the AM. --- .../scheduler/cluster/YarnClientSchedulerBackend.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index ce2dde0631ed9..2924189077b7d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -35,10 +35,10 @@ private[spark] class YarnClientSchedulerBackend( private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { - if (System.getProperty(sysProp) != null) { - arrayBuf += (optionName, System.getProperty(sysProp)) - } else if (System.getenv(envVar) != null) { + if (System.getenv(envVar) != null) { arrayBuf += (optionName, System.getenv(envVar)) + } else if (sc.getConf.contains(sysProp)) { + arrayBuf += (optionName, sc.getConf.get(sysProp)) } } From 06b15baab25951d124bbe6b64906f4139e037deb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 8 May 2014 22:26:17 -0700 Subject: [PATCH 22/34] SPARK-1565 (Addendum): Replace `run-example` with `spark-submit`. Gives a nicely formatted message to the user when `run-example` is run to tell them to use `spark-submit`. Author: Patrick Wendell Closes #704 from pwendell/examples and squashes the following commits: 1996ee8 [Patrick Wendell] Feedback form Andrew 3eb7803 [Patrick Wendell] Suggestions from TD 2474668 [Patrick Wendell] SPARK-1565 (Addendum): Replace `run-example` with `spark-submit`. --- README.md | 19 +++-- bin/pyspark | 2 +- bin/run-example | 71 +++++-------------- bin/spark-class | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- docs/running-on-yarn.md | 2 +- make-distribution.sh | 2 + 7 files changed, 37 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index e2d1dcb5672ff..9c2e32b90f162 100644 --- a/README.md +++ b/README.md @@ -39,17 +39,22 @@ And run the following command, which should also return 1000: ## Example Programs Spark also comes with several sample programs in the `examples` directory. -To run one of them, use `./bin/run-example `. For example: +To run one of them, use `./bin/run-example [params]`. For example: - ./bin/run-example org.apache.spark.examples.SparkLR local[2] + ./bin/run-example org.apache.spark.examples.SparkLR -will run the Logistic Regression example locally on 2 CPUs. +will run the Logistic Regression example locally. -Each of the example programs prints usage help if no params are given. +You can set the MASTER environment variable when running examples to submit +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You +can also use an abbreviated class name if the class is in the `examples` +package. For instance: -All of the Spark samples take a `` parameter that is the cluster URL -to connect to. This can be a mesos:// or spark:// URL, or "local" to run -locally with one thread, or "local[N]" to run locally with N threads. + MASTER=spark://host:7077 ./bin/run-example SparkPi + +Many of the example programs print usage help if no params are given. ## Running Tests diff --git a/bin/pyspark b/bin/pyspark index f5558853e8a4e..10e35e0f1734e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -31,7 +31,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2 - echo "You need to build Spark with sbt/sbt assembly before running this program" >&2 + echo "You need to build Spark before running this program" >&2 exit 1 fi fi diff --git a/bin/run-example b/bin/run-example index d8a94f2e31e07..146951ac0ee56 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,28 +17,10 @@ # limitations under the License. # -cygwin=false -case "`uname`" in - CYGWIN*) cygwin=true;; -esac - SCALA_VERSION=2.10 -# Figure out where the Scala framework is installed FWDIR="$(cd `dirname $0`/..; pwd)" - -# Export this as SPARK_HOME export SPARK_HOME="$FWDIR" - -. $FWDIR/bin/load-spark-env.sh - -if [ -z "$1" ]; then - echo "Usage: run-example []" >&2 - exit 1 -fi - -# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack -# to avoid the -sources and -doc packages that are built by publish-local. EXAMPLES_DIR="$FWDIR"/examples if [ -f "$FWDIR/RELEASE" ]; then @@ -49,46 +31,29 @@ fi if [[ -z $SPARK_EXAMPLES_JAR ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" >&2 - echo "You need to build Spark with sbt/sbt assembly before running this program" >&2 + echo "You need to build Spark before running this program" >&2 exit 1 fi +EXAMPLE_MASTER=${MASTER:-"local[*]"} -# Since the examples JAR ideally shouldn't include spark-core (that dependency should be -# "provided"), also add our standard Spark classpath, built using compute-classpath.sh. -CLASSPATH=`$FWDIR/bin/compute-classpath.sh` -CLASSPATH="$SPARK_EXAMPLES_JAR:$CLASSPATH" - -if $cygwin; then - CLASSPATH=`cygpath -wp $CLASSPATH` - export SPARK_EXAMPLES_JAR=`cygpath -w $SPARK_EXAMPLES_JAR` -fi - -# Find java binary -if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" -else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi -fi - -# Set JAVA_OPTS to be able to load native libraries and to set heap size -JAVA_OPTS="$SPARK_JAVA_OPTS" -# Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" +if [ -n "$1" ]; then + EXAMPLE_CLASS="$1" + shift +else + echo "usage: ./bin/run-example [example-args]" + echo " - set MASTER=XX to use a specific master" + echo " - can use abbreviated example class name (e.g. SparkPi, mllib.MovieLensALS)" + echo + exit -1 fi -export JAVA_OPTS -if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then - echo -n "Spark Command: " - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" - echo "========================================" - echo +if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then + EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" +./bin/spark-submit \ + --master $EXAMPLE_MASTER \ + --class $EXAMPLE_CLASS \ + $SPARK_EXAMPLES_JAR \ + "$@" diff --git a/bin/spark-class b/bin/spark-class index 72f8b9bf9a495..6480ccb58d6aa 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -114,7 +114,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar") if [ "$num_jars" -eq "0" ]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2 - echo "You need to build Spark with 'sbt/sbt assembly' before running this program." >&2 + echo "You need to build Spark before running this program." >&2 exit 1 fi if [ "$num_jars" -gt "1" ]; then diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a1ca612cc9a09..9d8d8044f07eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -330,9 +330,9 @@ abstract class RDD[T: ClassTag]( if (shuffle) { // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), + new ShuffledRDD[Int, T, (Int, T)](map(x => (Utils.random.nextInt(), x)), new HashPartitioner(numPartitions)), - numPartitions).keys + numPartitions).values } else { new CoalescedRDD(this, numPartitions) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68183ee8b4613..c563594296802 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -53,7 +53,7 @@ For example: --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 - examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + lib/spark-examples*.jar \ yarn-cluster 5 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. diff --git a/make-distribution.sh b/make-distribution.sh index 759e555b4b69a..1cc2844703fbb 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -40,6 +40,8 @@ # set -o pipefail +set -e + # Figure out where the Spark framework is installed FWDIR="$(cd `dirname $0`; pwd)" DISTDIR="$FWDIR/dist" From 7db47c463fefc244e9c100d4aab90451c3828261 Mon Sep 17 00:00:00 2001 From: Sandeep Date: Thu, 8 May 2014 22:30:17 -0700 Subject: [PATCH 23/34] SPARK-1775: Unneeded lock in ShuffleMapTask.deserializeInfo This was used in the past to have a cache of deserialized ShuffleMapTasks, but that's been removed, so there's no need for a lock. It slows down Spark when task descriptions are large, e.g. due to large lineage graphs or local variables. Author: Sandeep Closes #707 from techaddict/SPARK-1775 and squashes the following commits: 18d8ebf [Sandeep] SPARK-1775: Unneeded lock in ShuffleMapTask.deserializeInfo This was used in the past to have a cache of deserialized ShuffleMapTasks, but that's been removed, so there's no need for a lock. It slows down Spark when task descriptions are large, e.g. due to large lineage graphs or local variables. --- .../apache/spark/scheduler/ShuffleMapTask.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 4b0324f2b5447..9ba586f7581cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -57,15 +57,13 @@ private[spark] object ShuffleMapTask { } def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - synchronized { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - (rdd, dep) - } + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance() + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + (rdd, dep) } // Since both the JarSet and FileSet have the same format this is used for both. From 4c60fd1e8c526278b7e5544d6164050d1aee0338 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 8 May 2014 22:33:06 -0700 Subject: [PATCH 24/34] MINOR: Removing dead code. Meant to do this when patching up the last merge. --- .../main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 9ba586f7581cf..ed0f56f1abdf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -57,7 +57,6 @@ private[spark] object ShuffleMapTask { } def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) From 32868f31f88aebd580ab9329dc51a30c26af7a74 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 8 May 2014 22:34:08 -0700 Subject: [PATCH 25/34] Converted bang to ask to avoid scary warning when a block is removed Removing a block through the blockmanager gave a scary warning messages in the driver. ``` 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true ``` This is because the [BlockManagerSlaveActor](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala#L44) would send back an acknowledgement ("true"). But the BlockManagerMasterActor would have sent the RemoveBlock message as a send, not as ask(), so would reject the receiver "true" as a unknown message. @pwendell Author: Tathagata Das Closes #708 from tdas/bm-fix and squashes the following commits: ed4ef15 [Tathagata Das] Converted bang to ask to avoid scary warning when a block is removed. --- .../org/apache/spark/storage/BlockManagerMasterActor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 98fa0df6ec289..6aed322eeb185 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -250,7 +250,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor ! RemoveBlock(blockId) + blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) } } } From bd67551ee724fd7cce4f2e2977a862216c992ef5 Mon Sep 17 00:00:00 2001 From: witgo Date: Fri, 9 May 2014 01:51:26 -0700 Subject: [PATCH 26/34] [SPARK-1760]: fix building spark with maven documentation Author: witgo Closes #712 from witgo/building-with-maven and squashes the following commits: 215523b [witgo] fix building spark with maven documentation --- docs/building-with-maven.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index cac01ded60d94..b6dd553bbe06b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,7 +96,7 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o The ScalaTest plugin also supports running only a specific test suite as follows: - $ mvn -Dhadoop.version=... -Dsuites=org.apache.spark.repl.ReplSuite test + $ mvn -Dhadoop.version=... -DwildcardSuites=org.apache.spark.repl.ReplSuite test ## Continuous Compilation ## From 59577df14c06417676a9ffdd599f5713c448e299 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 9 May 2014 14:51:34 -0700 Subject: [PATCH 27/34] SPARK-1770: Revert accidental(?) fix Looks like this change was accidentally committed here: https://github.com/apache/spark/commit/06b15baab25951d124bbe6b64906f4139e037deb but the change does not show up in the PR itself (#704). Other than not intending to go in with that PR, this also broke the test JavaAPISuite.repartition. Author: Aaron Davidson Closes #716 from aarondav/shufflerand and squashes the following commits: b1cf70b [Aaron Davidson] SPARK-1770: Revert accidental(?) fix --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9d8d8044f07eb..a1ca612cc9a09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -330,9 +330,9 @@ abstract class RDD[T: ClassTag]( if (shuffle) { // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, (Int, T)](map(x => (Utils.random.nextInt(), x)), + new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), new HashPartitioner(numPartitions)), - numPartitions).values + numPartitions).keys } else { new CoalescedRDD(this, numPartitions) } From 2f452cbaf35dbc609ab48ec0ee5e3dd7b6b9b790 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 9 May 2014 21:50:23 -0700 Subject: [PATCH 28/34] SPARK-1686: keep schedule() calling in the main thread https://issues.apache.org/jira/browse/SPARK-1686 moved from original JIRA (by @markhamstra): In deploy.master.Master, the completeRecovery method is the last thing to be called when a standalone Master is recovering from failure. It is responsible for resetting some state, relaunching drivers, and eventually resuming its scheduling duties. There are currently four places in Master.scala where completeRecovery is called. Three of them are from within the actor's receive method, and aren't problems. The last starts from within receive when the ElectedLeader message is received, but the actual completeRecovery() call is made from the Akka scheduler. That means that it will execute on a different scheduler thread, and Master itself will end up running (i.e., schedule() ) from that Akka scheduler thread. In this PR, I added a new master message TriggerSchedule to trigger the "local" call of schedule() in the scheduler thread Author: CodingCat Closes #639 from CodingCat/SPARK-1686 and squashes the following commits: 81bb4ca [CodingCat] rename variable 69e0a2a [CodingCat] style fix 36a2ac0 [CodingCat] address Aaron's comments ec9b7bb [CodingCat] address the comments 02b37ca [CodingCat] keep schedule() calling in the main thread --- .../org/apache/spark/deploy/master/Master.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fdb633bd33608..f254f5585ba25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -104,6 +104,8 @@ private[spark] class Master( var leaderElectionAgent: ActorRef = _ + private var recoveryCompletionTask: Cancellable = _ + // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. @@ -152,6 +154,10 @@ private[spark] class Master( } override def postStop() { + // prevent the CompleteRecovery message sending to restarted master + if (recoveryCompletionTask != null) { + recoveryCompletionTask.cancel() + } webUi.stop() fileSystemsUsed.foreach(_.close()) masterMetricsSystem.stop() @@ -171,10 +177,13 @@ private[spark] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } + recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, + CompleteRecovery) } } + case CompleteRecovery => completeRecovery() + case RevokedLeadership => { logError("Leadership has been revoked -- master shutting down.") System.exit(0) @@ -465,7 +474,7 @@ private[spark] class Master( * Schedule the currently available resources among waiting apps. This method will be called * every time a new app joins or resource availability changes. */ - def schedule() { + private def schedule() { if (state != RecoveryState.ALIVE) { return } // First schedule drivers, they take strict precedence over applications @@ -485,7 +494,7 @@ private[spark] class Master( // Try to spread out each app among all the nodes, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(canUse(app, _)).sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) From 561510867a1b79beef57acf9df65c9f88481435d Mon Sep 17 00:00:00 2001 From: witgo Date: Sat, 10 May 2014 10:15:04 -0700 Subject: [PATCH 29/34] [SPARK-1644] The org.datanucleus:* should not be packaged into spark-assembly-*.jar Author: witgo Closes #688 from witgo/SPARK-1644 and squashes the following commits: 56ad6ac [witgo] review commit 87c03e4 [witgo] Merge branch 'master' of https://github.com/apache/spark into SPARK-1644 6ffa7e4 [witgo] review commit a597414 [witgo] The org.datanucleus:* should not be packaged into spark-assembly-*.jar --- assembly/pom.xml | 1 + project/SparkBuild.scala | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 7d123fb1d7f02..6c4d46aeb67bd 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -96,6 +96,7 @@ *:* + org.datanucleus:* META-INF/*.SF META-INF/*.DSA META-INF/*.RSA diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7f9746ec4acc0..27e9505ec9831 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -579,12 +579,13 @@ object SparkBuild extends Build { def extraAssemblySettings() = Seq( test in assembly := {}, mergeStrategy in assembly := { - case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard - case "log4j.properties" => MergeStrategy.discard + case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard + case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard + case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case "log4j.properties" => MergeStrategy.discard case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines - case "reference.conf" => MergeStrategy.concat - case _ => MergeStrategy.first + case "reference.conf" => MergeStrategy.concat + case _ => MergeStrategy.first } ) From 4d6055329846f5e09472e5f844127a5ab5880e15 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 10 May 2014 11:48:01 -0700 Subject: [PATCH 30/34] [SQL] Upgrade parquet library. I think we are hitting this issue in some perf tests: https://github.com/Parquet/parquet-mr/commit/6aed5288fd4a1398063a5a219b2ae4a9f71b02cf Credit to @aarondav ! Author: Michael Armbrust Closes #684 from marmbrus/upgradeParquet and squashes the following commits: e10a619 [Michael Armbrust] Upgrade parquet library. --- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index fb904e868cfaf..e0bff60a54cde 100644 --- a/pom.xml +++ b/pom.xml @@ -123,7 +123,7 @@ ${hadoop.version} 0.94.6 0.12.0 - 1.3.2 + 1.4.3 1.2.3 8.1.14.v20131031 0.3.6 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 27e9505ec9831..af882b3ea7beb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -300,7 +300,7 @@ object SparkBuild extends Build { val jets3tVersion = if ("^2\\.[3-9]+".r.findFirstIn(hadoopVersion).isDefined) "0.9.0" else "0.7.1" val jettyVersion = "8.1.14.v20131031" val hiveVersion = "0.12.0" - val parquetVersion = "1.3.2" + val parquetVersion = "1.4.3" val slf4jVersion = "1.7.5" val excludeNetty = ExclusionRule(organization = "org.jboss.netty") From 8e94d2721a9d3d36697e13f8cc6567ae8aeee78b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 10 May 2014 12:03:27 -0700 Subject: [PATCH 31/34] [SPARK-1778] [SQL] Add 'limit' transformation to SchemaRDD. Add `limit` transformation to `SchemaRDD`. Author: Takuya UESHIN Closes #711 from ueshin/issues/SPARK-1778 and squashes the following commits: 33169df [Takuya UESHIN] Add 'limit' transformation to SchemaRDD. --- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DslQuerySuite.scala | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 34200be3ac955..2569815ebb209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -178,6 +178,15 @@ class SchemaRDD( def orderBy(sortExprs: SortOrder*): SchemaRDD = new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + /** + * Limits the results by the given expressions. + * {{{ + * schemaRDD.limit(10) + * }}} + */ + def limit(limitExpr: Expression): SchemaRDD = + new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + /** * Performs a grouping followed by an aggregation. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index be0f4a4c73b36..92a707ea57504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -71,6 +71,12 @@ class DslQuerySuite extends QueryTest { Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) } + test("limit") { + checkAnswer( + testData.limit(10), + testData.take(10).toSeq) + } + test("average") { checkAnswer( testData2.groupBy()(Average('a)), From 7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 10 May 2014 12:10:24 -0700 Subject: [PATCH 32/34] SPARK-1708. Add a ClassTag on Serializer and things that depend on it This pull request contains a rebased patch from @heathermiller (https://github.com/heathermiller/spark/pull/1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility. One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection in contrast were being serialized directly. CC @rxin, @pwendell, @heathermiller Author: Matei Zaharia Closes #700 from mateiz/spark-1708 and squashes the following commits: 1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java 3b449ed [Matei Zaharia] test fix 2209a27 [Matei Zaharia] Code style fixes 9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it --- .../scala/org/apache/spark/Accumulators.scala | 7 +-- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../spark/api/java/JavaSparkContext.scala | 2 +- .../apache/spark/broadcast/Broadcast.scala | 4 +- .../spark/broadcast/BroadcastFactory.scala | 4 +- .../spark/broadcast/BroadcastManager.scala | 4 +- .../spark/broadcast/HttpBroadcast.scala | 7 ++- .../broadcast/HttpBroadcastFactory.scala | 4 +- .../spark/broadcast/TorrentBroadcast.scala | 4 +- .../broadcast/TorrentBroadcastFactory.scala | 4 +- .../org/apache/spark/rdd/CheckpointRDD.scala | 4 +- .../spark/rdd/ParallelCollectionRDD.scala | 2 +- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/serializer/JavaSerializer.scala | 13 +++--- .../spark/serializer/KryoSerializer.scala | 12 ++--- .../apache/spark/serializer/Serializer.scala | 17 +++---- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../serializer/KryoSerializerSuite.scala | 11 ++--- .../bagel/WikipediaPageRankStandalone.scala | 12 ++--- .../spark/graphx/impl/Serializers.scala | 45 ++++++++++--------- .../apache/spark/graphx/SerializerSuite.scala | 5 ++- .../sql/execution/SparkSqlSerializer.scala | 6 ++- 22 files changed, 103 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 6d652faae149a..cdfd338081fa2 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable { def zero(initialValue: R): R } -private[spark] -class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] - extends AccumulableParam[R,T] { +private[spark] class +GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] + extends AccumulableParam[R, T] { def addAccumulator(growable: R, elem: T): R = { growable += elem diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9d7c2c8d3d630..c639b3e15ded5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging { * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by * standard mutable collections. So you can use this with mutable Map, Set, etc. */ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] new Accumulable(initialValue, param) @@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging { * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = { + def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8b95cda511643..a7cfee6d01711 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) + def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag) /** Shut down the SparkContext. */ def stop() { diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 738a3b1bed7f3..76956f6a345d1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -21,6 +21,8 @@ import java.io.Serializable import org.apache.spark.SparkException +import scala.reflect.ClassTag + /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. They can be used, for @@ -50,7 +52,7 @@ import org.apache.spark.SparkException * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. */ -abstract class Broadcast[T](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** * Flag signifying whether the broadcast variable is valid diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 8c8ce9b1691ac..a8c827030a1ef 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.SecurityManager import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi @@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index cf62aca4d45e8..c88be6aba6901 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.broadcast import java.util.concurrent.atomic.AtomicLong +import scala.reflect.ClassTag + import org.apache.spark._ private[spark] class BroadcastManager( @@ -56,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 29372f16f2cac..78fc286e5192c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit +import scala.reflect.ClassTag + import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} @@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH * (through a HTTP server running at the driver) and stored in the BlockManager of the * executor to speed up future accesses. */ -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T: ClassTag]( + @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def getValue = value_ @@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging { files += file.getAbsolutePath } - def read[T](id: Long): T = { + def read[T: ClassTag](id: Long): T = { logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index e3f6cdc6154dd..d5a031e2bbb59 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.{SecurityManager, SparkConf} /** @@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2659274c5e98e..734de37ba115d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} +import scala.reflect.ClassTag import scala.math import scala.util.Random @@ -44,7 +45,8 @@ import org.apache.spark.util.Utils * copies of the broadcast data (one per executor) as done by the * [[org.apache.spark.broadcast.HttpBroadcast]]. */ -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TorrentBroadcast[T: ClassTag]( + @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def getValue = value_ diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index d216b58718148..1de8396a0e17f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.{SecurityManager, SparkConf} /** @@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) def stop() { TorrentBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 888af541cf970..34c51b833025e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -84,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging { "part-%05d".format(splitId) } - def writeToFile[T]( + def writeToFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableWritable[Configuration]], blockSize: Int = -1 @@ -160,7 +160,7 @@ private[spark] object CheckpointRDD extends Logging { val conf = SparkHadoopUtil.get.newConfiguration() val fs = path.getFileSystem(conf) val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _) + sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 5f03d7d650a30..2425929fc73c5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -77,7 +77,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( slice = in.readInt() val ser = sfactory.newInstance() - Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject[Seq[T]]()) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 953f0555e57c5..c3b2a33fb54d0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -92,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( new SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (newRDD.partitions.size != rdd.partitions.size) { throw new SparkException( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index e9163deaf2036..0a7e1ec539679 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -20,6 +20,8 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream @@ -36,7 +38,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * But only call it every 10,000th time to avoid bloated serialization streams (when * the stream 'resets' object class descriptions have to be re-written) */ - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) if (counterReset > 0 && counter >= counterReset) { objOut.reset() @@ -46,6 +48,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In } this } + def flush() { objOut.flush() } def close() { objOut.close() } } @@ -57,12 +60,12 @@ extends DeserializationStream { Class.forName(desc.getName, false, loader) } - def readObject[T](): T = objIn.readObject().asInstanceOf[T] + def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -70,13 +73,13 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) in.readObject().asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) in.readObject().asInstanceOf[T] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index c4daec7875d26..5286f7b4c211a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,6 +31,8 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} +import scala.reflect.ClassTag + /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. * @@ -95,7 +97,7 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } @@ -108,7 +110,7 @@ private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { val input = new KryoInput(inStream) - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -131,18 +133,18 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ lazy val output = ks.newKryoOutput() lazy val input = new KryoInput() - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2c8f9b6218d6..ee26970a3d874 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -20,6 +20,8 @@ package org.apache.spark.serializer import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.{ByteBufferInputStream, NextIterator} @@ -59,17 +61,17 @@ object Serializer { */ @DeveloperApi trait SerializerInstance { - def serialize[T](t: T): ByteBuffer + def serialize[T: ClassTag](t: T): ByteBuffer - def deserialize[T](bytes: ByteBuffer): T + def deserialize[T: ClassTag](bytes: ByteBuffer): T - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { + def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { // Default implementation uses serializeStream val stream = new ByteArrayOutputStream() serializeStream(stream).writeAll(iterator) @@ -85,18 +87,17 @@ trait SerializerInstance { } } - /** * :: DeveloperApi :: * A stream for writing serialized objects. */ @DeveloperApi trait SerializationStream { - def writeObject[T](t: T): SerializationStream + def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit - def writeAll[T](iter: Iterator[T]): SerializationStream = { + def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { writeObject(iter.next()) } @@ -111,7 +112,7 @@ trait SerializationStream { */ @DeveloperApi trait DeserializationStream { - def readObject[T](): T + def readObject[T: ClassTag](): T def close(): Unit /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3f0ed61c5bbfb..95777fbf57d8b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -850,7 +850,7 @@ private[spark] object Utils extends Logging { /** * Clone an object using a Spark serializer. */ - def clone[T](value: T, serializer: SerializerInstance): T = { + def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = { serializer.deserialize[T](serializer.serialize(value)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 5d4673aebe9e8..cdd6b3d8feed7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import scala.collection.mutable +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite @@ -31,7 +32,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("basic types") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(1) @@ -61,7 +62,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("pairs") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check((1, 1)) @@ -85,7 +86,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("Scala data structures") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) @@ -108,7 +109,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("ranges") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time assert(ser.serialize(t).limit < 100) @@ -129,7 +130,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("custom registrator") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index a197dac87d6db..576a3e371b993 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -28,6 +28,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag + object WikipediaPageRankStandalone { def main(args: Array[String]) { if (args.length < 4) { @@ -143,15 +145,15 @@ class WPRSerializer extends org.apache.spark.serializer.Serializer { } class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { throw new UnsupportedOperationException() } @@ -167,7 +169,7 @@ class WPRSerializerInstance extends SerializerInstance { class WPRSerializationStream(os: OutputStream) extends SerializationStream { val dos = new DataOutputStream(os) - def writeObject[T](t: T): SerializationStream = t match { + def writeObject[T: ClassTag](t: T): SerializationStream = t match { case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { case links: Array[String] => { dos.writeInt(0) // links @@ -200,7 +202,7 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream { class WPRDeserializationStream(is: InputStream) extends DeserializationStream { val dis = new DataInputStream(is) - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val typeId = dis.readInt() typeId match { case 0 => { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 2f0531ee5f379..1de42eeca1f00 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -17,20 +17,22 @@ package org.apache.spark.graphx.impl +import scala.language.existentials + import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.graphx._ import org.apache.spark.serializer._ -import scala.language.existentials - private[graphx] class VertexIdMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, _)] writeVarLong(msg._1, optimizePositive = false) this @@ -38,7 +40,7 @@ class VertexIdMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { (readVarLong(optimizePositive = false), null).asInstanceOf[T] } } @@ -51,7 +53,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[VertexBroadcastMsg[Int]] writeVarLong(msg.vid, optimizePositive = false) writeInt(msg.data) @@ -60,7 +62,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readInt() new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T] @@ -75,7 +77,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[VertexBroadcastMsg[Long]] writeVarLong(msg.vid, optimizePositive = false) writeLong(msg.data) @@ -84,7 +86,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readLong() new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T] @@ -99,7 +101,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[VertexBroadcastMsg[Double]] writeVarLong(msg.vid, optimizePositive = false) writeDouble(msg.data) @@ -108,7 +110,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readDouble() new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T] @@ -123,7 +125,7 @@ class IntAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Int)] writeVarLong(msg._1, optimizePositive = false) writeUnsignedVarInt(msg._2) @@ -132,7 +134,7 @@ class IntAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readUnsignedVarInt() (a, b).asInstanceOf[T] @@ -147,7 +149,7 @@ class LongAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Long)] writeVarLong(msg._1, optimizePositive = false) writeVarLong(msg._2, optimizePositive = true) @@ -156,7 +158,7 @@ class LongAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readVarLong(optimizePositive = true) (a, b).asInstanceOf[T] @@ -171,7 +173,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Double)] writeVarLong(msg._1, optimizePositive = false) writeDouble(msg._2) @@ -180,7 +182,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readDouble() (a, b).asInstanceOf[T] @@ -196,7 +198,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { private[graphx] abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { // The implementation should override this one. - def writeObject[T](t: T): SerializationStream + def writeObject[T: ClassTag](t: T): SerializationStream def writeInt(v: Int) { s.write(v >> 24) @@ -309,7 +311,7 @@ abstract class ShuffleSerializationStream(s: OutputStream) extends Serialization private[graphx] abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { // The implementation should override this one. - def readObject[T](): T + def readObject[T: ClassTag](): T def readInt(): Int = { val first = s.read() @@ -398,11 +400,12 @@ abstract class ShuffleDeserializationStream(s: InputStream) extends Deserializat private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException - override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = throw new UnsupportedOperationException // The implementation should override the following two. diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index 73438d9535962..91caa6b605a1e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random +import scala.reflect.ClassTag import org.scalatest.FunSuite @@ -164,7 +165,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext { def testVarLongEncoding(v: Long, optimizePositive: Boolean) { val bout = new ByteArrayOutputStream val stream = new ShuffleSerializationStream(bout) { - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) this } @@ -173,7 +174,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext { val bin = new ByteArrayInputStream(bout.toByteArray) val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { readVarLong(optimizePositive).asInstanceOf[T] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 5067c14ddffeb..1c6e29b3cdee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} @@ -59,11 +61,11 @@ private[sql] object SparkSqlSerializer { new KryoSerializer(sparkConf) } - def serialize[T](o: T): Array[Byte] = { + def serialize[T: ClassTag](o: T): Array[Byte] = { ser.newInstance().serialize(o).array() } - def deserialize[T](bytes: Array[Byte]): T = { + def deserialize[T: ClassTag](bytes: Array[Byte]): T = { ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) } } From c05d11bb307eaba40c5669da2d374c28debaa55a Mon Sep 17 00:00:00 2001 From: Andy Konwinski Date: Sat, 10 May 2014 12:46:51 -0700 Subject: [PATCH 33/34] fix broken in link in python docs Author: Andy Konwinski Closes #650 from andyk/python-docs-link-fix and squashes the following commits: a1f9d51 [Andy Konwinski] fix broken in link in python docs --- docs/python-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 6813963bb080c..39fb5f0c99ca3 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -45,7 +45,7 @@ errors = logData.filter(is_error) PySpark will automatically ship these functions to executors, along with any objects that they reference. Instances of classes will be serialized and shipped to executors by PySpark, but classes themselves cannot be automatically distributed to executors. -The [Standalone Use](#standalone-use) section describes how to ship code dependencies to executors. +The [Standalone Use](#standalone-programs) section describes how to ship code dependencies to executors. In addition, PySpark fully supports interactive use---simply run `./bin/pyspark` to launch an interactive shell. From 3776f2f283842543ff766398292532c6e94221cc Mon Sep 17 00:00:00 2001 From: Bouke van der Bijl Date: Sat, 10 May 2014 13:02:13 -0700 Subject: [PATCH 34/34] Add Python includes to path before depickling broadcast values This fixes https://issues.apache.org/jira/browse/SPARK-1731 by adding the Python includes to the PYTHONPATH before depickling the broadcast values @airhorns Author: Bouke van der Bijl Closes #656 from bouk/python-includes-before-broadcast and squashes the following commits: 7b0dfe4 [Bouke van der Bijl] Add Python includes to path before depickling broadcast values --- .../org/apache/spark/api/python/PythonRDD.scala | 10 +++++----- python/pyspark/worker.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fecd9762f3f60..388b838d78bba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -179,6 +179,11 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(split.index) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -186,11 +191,6 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(broadcast.value.length) dataOut.write(broadcast.value) } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } dataOut.flush() // Serialized command: dataOut.writeInt(command.length) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4c214ef359685..f43210c6c0301 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -56,13 +56,6 @@ def main(infile, outfile): SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True - # fetch names and values of broadcast variables - num_broadcast_variables = read_int(infile) - for _ in range(num_broadcast_variables): - bid = read_long(infile) - value = pickleSer._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) @@ -70,6 +63,13 @@ def main(infile, outfile): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) + # fetch names and values of broadcast variables + num_broadcast_variables = read_int(infile) + for _ in range(num_broadcast_variables): + bid = read_long(infile) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time()