}
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
@@ -160,42 +160,45 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
totalExecutionTime - metrics.get.executorRunTime
}
- val schedulerDelayQuantiles = "Scheduler delay" +:
+ val schedulerDelayTitle =
Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
val executorTable = new ExecutorTable(stageId, parent)
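For reference, a minimal sketch of the scheduler-delay computation described in the comment above: total task time minus the time the task actually spent running on the executor. The helper below is illustrative only; the field names mirror Spark's TaskInfo/TaskMetrics, but it is not the UI code itself.

{% highlight scala %}
// Hypothetical helper, assuming launchTime/finishTime come from TaskInfo and
// executorRunTime from TaskMetrics. Anything outside executorRunTime is treated
// as scheduling overhead plus the network round trip to and from the worker.
def schedulerDelay(launchTime: Long, finishTime: Long, executorRunTime: Long): Long = {
  val totalExecutionTime = finishTime - launchTime
  math.max(totalExecutionTime - executorRunTime, 0L)
}
{% endhighlight %}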
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index a9ac6d5bee9c9..fd8d0b5cdde00 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.HashMap
import scala.xml.Node
import org.apache.spark.scheduler.{StageInfo, TaskInfo}
-import org.apache.spark.ui.UIUtils
+import org.apache.spark.ui.{ToolTips, UIUtils}
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
@@ -43,9 +43,16 @@ private[ui] class StageTableBase(
Submitted
Duration
Tasks: Succeeded/Total
-
Input
-
Shuffle Read
-
Shuffle Write
+
Input
+
Shuffle Read
+
+
+
+ Shuffle Write
+
+
}
def toNodeSeq: Seq[Node] = {
@@ -82,7 +89,8 @@ private[ui] class StageTableBase(
// scalastyle:off
val killLink = if (killEnabled) {
- (kill)
+ (kill)
}
// scalastyle:on
@@ -102,7 +110,7 @@ private[ui] class StageTableBase(
listener.stageIdToDescription.get(s.stageId)
.map(d =>
{d}
{nameLink} {killLink}
)
- .getOrElse(
{killLink} {nameLink} {details}
)
+ .getOrElse(
{nameLink} {killLink} {details}
)
}
protected def stageRow(s: StageInfo): Seq[Node] = {
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 070e974657860..c70e22cf09433 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -177,6 +177,31 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa")))
}
+ test("object files of classes from a JAR") {
+ val original = Thread.currentThread().getContextClassLoader
+ val className = "FileSuiteObjectFileTest"
+ val jar = TestUtils.createJarWithClasses(Seq(className))
+ val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
+ Thread.currentThread().setContextClassLoader(loader)
+ try {
+ sc = new SparkContext("local", "test")
+ val objs = sc.makeRDD(1 to 3).map { x =>
+ val loader = Thread.currentThread().getContextClassLoader
+ Class.forName(className, true, loader).newInstance()
+ }
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ objs.saveAsObjectFile(outputDir)
+ // Try reading the output back as an object file
+ val ct = reflect.ClassTag[Any](Class.forName(className, true, loader))
+ val output = sc.objectFile[Any](outputDir)
+ assert(output.collect().size === 3)
+ assert(output.collect().head.getClass.getName === className)
+ }
+ finally {
+ Thread.currentThread().setContextClassLoader(original)
+ }
+ }
+
test("write SequenceFile using new Hadoop API") {
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
sc = new SparkContext("local", "test")
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
similarity index 98%
rename from core/src/test/scala/org/apache/spark/BroadcastSuite.scala
rename to core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index c9936256a5b95..7c3d0208b195a 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -15,14 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.broadcast
+import org.apache.spark.storage.{BroadcastBlockId, _}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.scalatest.FunSuite
-import org.apache.spark.storage._
-import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
-import org.apache.spark.storage.BroadcastBlockId
-
class BroadcastSuite extends FunSuite with LocalSparkContext {
private val httpConf = broadcastConf("HttpBroadcastFactory")
diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
similarity index 97%
rename from core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
rename to core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
index df6b2604c8d8a..415ad8c432c12 100644
--- a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
@@ -15,15 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark
-
-import org.scalatest.FunSuite
+package org.apache.spark.network
import java.nio._
-import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId}
-import scala.concurrent.Await
-import scala.concurrent.TimeoutException
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.scalatest.FunSuite
+
+import scala.concurrent.{Await, TimeoutException}
import scala.concurrent.duration._
import scala.language.postfixOps
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
similarity index 95%
rename from core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
rename to core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index db56a4acdd6f5..be972c5e97a7e 100644
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -15,25 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.rdd
import java.io.File
-import org.scalatest.FunSuite
-
-import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
-import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat}
+import org.apache.spark._
+import org.scalatest.FunSuite
import scala.collection.Map
import scala.language.postfixOps
import scala.sys.process._
import scala.util.Try
-import org.apache.hadoop.io.{Text, LongWritable}
-
-import org.apache.spark.executor.TaskMetrics
-
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
similarity index 95%
rename from core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
rename to core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
index 4f87fd8654c4a..72596e86865b2 100644
--- a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala
@@ -15,8 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.rdd
+import org.apache.spark.SharedSparkContext
import org.scalatest.FunSuite
object ZippedPartitionsSuite {
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
similarity index 99%
rename from core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
rename to core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index 4ab870e751778..c4765e53de17b 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -15,14 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark
-
-import org.scalatest.FunSuite
+package org.apache.spark.util
import akka.actor._
+import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.AkkaUtils
+import org.scalatest.FunSuite
+
import scala.concurrent.Await
/**
diff --git a/mllib/data/als/test.data b/data/mllib/als/test.data
similarity index 100%
rename from mllib/data/als/test.data
rename to data/mllib/als/test.data
diff --git a/data/kmeans_data.txt b/data/mllib/kmeans_data.txt
similarity index 100%
rename from data/kmeans_data.txt
rename to data/mllib/kmeans_data.txt
diff --git a/mllib/data/lr-data/random.data b/data/mllib/lr-data/random.data
similarity index 100%
rename from mllib/data/lr-data/random.data
rename to data/mllib/lr-data/random.data
diff --git a/data/lr_data.txt b/data/mllib/lr_data.txt
similarity index 100%
rename from data/lr_data.txt
rename to data/mllib/lr_data.txt
diff --git a/data/pagerank_data.txt b/data/mllib/pagerank_data.txt
similarity index 100%
rename from data/pagerank_data.txt
rename to data/mllib/pagerank_data.txt
diff --git a/mllib/data/ridge-data/lpsa.data b/data/mllib/ridge-data/lpsa.data
similarity index 100%
rename from mllib/data/ridge-data/lpsa.data
rename to data/mllib/ridge-data/lpsa.data
diff --git a/mllib/data/sample_libsvm_data.txt b/data/mllib/sample_libsvm_data.txt
similarity index 100%
rename from mllib/data/sample_libsvm_data.txt
rename to data/mllib/sample_libsvm_data.txt
diff --git a/mllib/data/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt
similarity index 100%
rename from mllib/data/sample_naive_bayes_data.txt
rename to data/mllib/sample_naive_bayes_data.txt
diff --git a/mllib/data/sample_svm_data.txt b/data/mllib/sample_svm_data.txt
similarity index 100%
rename from mllib/data/sample_svm_data.txt
rename to data/mllib/sample_svm_data.txt
diff --git a/mllib/data/sample_tree_data.csv b/data/mllib/sample_tree_data.csv
similarity index 100%
rename from mllib/data/sample_tree_data.csv
rename to data/mllib/sample_tree_data.csv
diff --git a/dev/run-tests b/dev/run-tests
index d9df020f7563c..edd17b53b3d8c 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -66,10 +66,10 @@ echo "========================================================================="
# (either resolution or compilation) prompts the user for input either q, r,
# etc to quit or retry. This echo is there to make it not block.
if [ -n "$_RUN_SQL_TESTS" ]; then
- echo -e "q\n" | SPARK_HIVE=true sbt/sbt clean assembly test | \
+ echo -e "q\n" | SPARK_HIVE=true sbt/sbt clean package assembly/assembly test | \
grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
else
- echo -e "q\n" | sbt/sbt clean assembly test | \
+ echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \
grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
fi
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
new file mode 100755
index 0000000000000..8dda671e976ce
--- /dev/null
+++ b/dev/run-tests-jenkins
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Wrapper script that runs the Spark tests then reports QA results
+# to github via its API.
+
+# Go to the Spark project root directory
+FWDIR="$(cd `dirname $0`/..; pwd)"
+cd $FWDIR
+
+COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments"
+
+function post_message {
+ message=$1
+ data="{\"body\": \"$message\"}"
+ echo "Attempting to post to Github:"
+ echo "$data"
+
+ curl -D- -u x-oauth-basic:$GITHUB_OAUTH_KEY -X POST --data "$data" -H \
+ "Content-Type: application/json" \
+ $COMMENTS_URL | head -n 8
+}
+
+start_message="QA tests have started for PR $ghprbPullId."
+if [ "$sha1" == "$ghprbActualCommit" ]; then
+ start_message="$start_message This patch DID NOT merge cleanly! "
+else
+ start_message="$start_message This patch merges cleanly. "
+fi
+start_message="$start_message View progress: "
+start_message="$start_message${BUILD_URL}consoleFull"
+
+post_message "$start_message"
+
+./dev/run-tests
+test_result="$?"
+
+result_message="QA results for PR $ghprbPullId: "
+
+if [ "$test_result" -eq "0" ]; then
+ result_message="$result_message- This patch PASSES unit tests. "
+else
+ result_message="$result_message- This patch FAILED unit tests. "
+fi
+
+if [ "$sha1" != "$ghprbActualCommit" ]; then
+ result_message="$result_message- This patch merges cleanly "
+ non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ")
+ new_public_classes=$(git diff master $non_test_files \
+ | grep -e "trait " -e "class " \
+ | grep -e "{" -e "(" \
+ | grep -v -e \@\@ -e private \
+ | grep \+ \
+ | sed "s/\+ *//" \
+ | tr "\n" "~" \
+ | sed "s/~/ /g")
+ if [ "$new_public_classes" == "" ]; then
+ result_message="$result_message- This patch adds no public classes "
+ else
+ result_message="$result_message- This patch adds the following public classes (experimental): "
+ result_message="$result_message$new_public_classes"
+ fi
+fi
+result_message="${result_message} For more information see test ouptut:"
+result_message="${result_message} ${BUILD_URL}consoleFull"
+
+post_message "$result_message"
+
+exit $test_result
diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md
index b280df0c8eeb8..7e55131754a3f 100644
--- a/docs/bagel-programming-guide.md
+++ b/docs/bagel-programming-guide.md
@@ -46,7 +46,7 @@ import org.apache.spark.bagel.Bagel._
Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it.
{% highlight scala %}
-val input = sc.textFile("data/pagerank_data.txt")
+val input = sc.textFile("data/mllib/pagerank_data.txt")
val numVerts = input.count()
diff --git a/docs/configuration.md b/docs/configuration.md
index b84104cc7e653..07aa4c035446b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -699,6 +699,25 @@ Apart from these, the following properties are also available, and may be useful
(in milliseconds)
+<tr>
+  <td><code>spark.scheduler.minRegisteredExecutorsRatio</code></td>
+  <td>0</td>
+  <td>
+    The minimum ratio of registered executors (registered executors / total expected executors)
+    to wait for before scheduling begins. Specified as a double between 0 and 1.
+    Regardless of whether the minimum ratio of executors has been reached,
+    the maximum amount of time it will wait before scheduling begins is controlled by config
+    <code>spark.scheduler.maxRegisteredExecutorsWaitingTime</code>
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.maxRegisteredExecutorsWaitingTime</code></td>
+  <td>30000</td>
+  <td>
+    Maximum amount of time to wait for executors to register before scheduling begins
+    (in milliseconds).
+  </td>
+</tr>
#### Security
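As a rough illustration of the two scheduler properties documented above (a sketch, not part of this patch), they can be set on a SparkConf before the context is created:

{% highlight scala %}
// Illustrative values: wait until at least 80% of the expected executors have
// registered, but never longer than 30 seconds, before scheduling the first stage.
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("RegistrationGatingExample")
  .set("spark.scheduler.minRegisteredExecutorsRatio", "0.8")
  .set("spark.scheduler.maxRegisteredExecutorsWaitingTime", "30000") // milliseconds
val sc = new SparkContext(conf)
{% endhighlight %}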
@@ -773,6 +792,15 @@ Apart from these, the following properties are also available, and may be useful
into blocks of data before storing them in Spark.
+<tr>
+  <td><code>spark.streaming.receiver.maxRate</code></td>
+  <td>infinite</td>
+  <td>
+    Maximum rate (per second) at which each receiver will push data into blocks. Effectively,
+    each stream will consume at most this number of records per second.
+    Setting this configuration to 0 or a negative number will put no limit on the rate.
+  </td>
+</tr>
spark.streaming.unpersist
true
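Similarly, a small sketch of how the receiver rate limit documented above might be applied (values and app name are placeholders):

{% highlight scala %}
// Cap each receiver at 10,000 records per second for a streaming job with 1-second batches.
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf()
  .setAppName("RateLimitedStream")
  .set("spark.streaming.receiver.maxRate", "10000")
val ssc = new StreamingContext(conf, Seconds(1))
{% endhighlight %}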
diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md
index 5796e16e8f99c..f9585251fafac 100644
--- a/docs/mllib-basics.md
+++ b/docs/mllib-basics.md
@@ -193,7 +193,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-val examples: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
+val examples: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
{% endhighlight %}
@@ -207,7 +207,7 @@ import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.api.java.JavaRDD;
JavaRDD examples =
- MLUtils.loadLibSVMFile(jsc.sc(), "mllib/data/sample_libsvm_data.txt").toJavaRDD();
+ MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
{% endhighlight %}
@@ -218,7 +218,7 @@ examples stored in LIBSVM format.
{% highlight python %}
from pyspark.mllib.util import MLUtils
-examples = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
+examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
{% endhighlight %}
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 429cdf8d40cec..c76ac010d3f81 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -51,7 +51,7 @@ import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
-val data = sc.textFile("data/kmeans_data.txt")
+val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
// Cluster the data into two classes using KMeans
@@ -86,7 +86,7 @@ from numpy import array
from math import sqrt
# Load and parse the data
-data = sc.textFile("data/kmeans_data.txt")
+data = sc.textFile("data/mllib/kmeans_data.txt")
parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))
# Build the model (cluster the data)
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index d51002f015670..5cd71738722a9 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -58,7 +58,7 @@ import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.Rating
// Load and parse the data
-val data = sc.textFile("mllib/data/als/test.data")
+val data = sc.textFile("data/mllib/als/test.data")
val ratings = data.map(_.split(',') match { case Array(user, item, rate) =>
Rating(user.toInt, item.toInt, rate.toDouble)
})
@@ -112,7 +112,7 @@ from pyspark.mllib.recommendation import ALS
from numpy import array
# Load and parse the data
-data = sc.textFile("mllib/data/als/test.data")
+data = sc.textFile("data/mllib/als/test.data")
ratings = data.map(lambda line: array([float(x) for x in line.split(',')]))
# Build the recommendation model using Alternating Least Squares
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 3002a66a4fdb3..9cd768599e529 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -122,7 +122,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Gini
// Load and parse the data file
-val data = sc.textFile("mllib/data/sample_tree_data.csv")
+val data = sc.textFile("data/mllib/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
@@ -161,7 +161,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Variance
// Load and parse the data file
-val data = sc.textFile("mllib/data/sample_tree_data.csv")
+val data = sc.textFile("data/mllib/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 4dfbebbcd04b7..b4d22e0df5a85 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -187,7 +187,7 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
// Load training data in LIBSVM format.
-val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
@@ -259,7 +259,7 @@ def parsePoint(line):
values = [float(x) for x in line.split(' ')]
return LabeledPoint(values[0], values[1:])
-data = sc.textFile("mllib/data/sample_svm_data.txt")
+data = sc.textFile("data/mllib/sample_svm_data.txt")
parsedData = data.map(parsePoint)
# Build the model
@@ -309,7 +309,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
-val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
+val data = sc.textFile("data/mllib/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
@@ -356,7 +356,7 @@ def parsePoint(line):
values = [float(x) for x in line.replace(',', ' ').split(' ')]
return LabeledPoint(values[0], values[1:])
-data = sc.textFile("mllib/data/ridge-data/lpsa.data")
+data = sc.textFile("data/mllib/ridge-data/lpsa.data")
parsedData = data.map(parsePoint)
# Build the model
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 1d1d7dcf6ffcb..b1650c83c98b9 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
-val data = sc.textFile("mllib/data/sample_naive_bayes_data.txt")
+val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt")
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md
index ae9ede58e8e60..651958c7812f2 100644
--- a/docs/mllib-optimization.md
+++ b/docs/mllib-optimization.md
@@ -214,7 +214,7 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.classification.LogisticRegressionModel
-val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val numFeatures = data.take(1)(0).features.size
// Split data into training (60%) and test (40%).
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index f5c2bfb697c81..44775ea479ece 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -428,11 +428,11 @@ def launch_cluster(conn, opts, cluster_name):
for master in master_nodes:
master.add_tag(
key='Name',
- value='spark-{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
+ value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
for slave in slave_nodes:
slave.add_tag(
key='Name',
- value='spark-{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
+ value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
# Return all the instances
return (master_nodes, slave_nodes)
@@ -699,6 +699,7 @@ def ssh(host, opts, command):
time.sleep(30)
tries = tries + 1
+
# Backported from Python 2.7 for compatibility with 2.6 (See SPARK-1990)
def _check_output(*popenargs, **kwargs):
if 'stdout' in kwargs:
diff --git a/examples/pom.xml b/examples/pom.xml
index 4f6d7fdb87d47..bd1c387c2eb91 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-examples_2.10
+
+ examples
+ jarSpark Project Exampleshttp://spark.apache.org/
diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
index 4893b017ed819..822673347bdce 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
@@ -31,12 +31,12 @@ object HBaseTest {
val conf = HBaseConfiguration.create()
// Other options for configuring scan behavior are available. More information available at
// http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html
- conf.set(TableInputFormat.INPUT_TABLE, args(1))
+ conf.set(TableInputFormat.INPUT_TABLE, args(0))
// Initialize hBase table if necessary
val admin = new HBaseAdmin(conf)
- if(!admin.isTableAvailable(args(1))) {
- val tableDesc = new HTableDescriptor(args(1))
+ if (!admin.isTableAvailable(args(0))) {
+ val tableDesc = new HTableDescriptor(args(0))
admin.createTable(tableDesc)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
index 331de3ad1ef53..ed2b38e2ca6f8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala
@@ -19,16 +19,22 @@ package org.apache.spark.examples
import org.apache.spark._
+
object HdfsTest {
+
+ /** Usage: HdfsTest [file] */
def main(args: Array[String]) {
+ if (args.length < 1) {
+ System.err.println("Usage: HdfsTest ")
+ System.exit(1)
+ }
val sparkConf = new SparkConf().setAppName("HdfsTest")
val sc = new SparkContext(sparkConf)
- val file = sc.textFile(args(1))
+ val file = sc.textFile(args(0))
val mapped = file.map(s => s.length).cache()
for (iter <- 1 to 10) {
val start = System.currentTimeMillis()
for (x <- mapped) { x + 2 }
- // println("Processing: " + x)
val end = System.currentTimeMillis()
println("Iteration " + iter + " took " + (end-start) + " ms")
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index 4d28e0aad6597..79cfedf332436 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -17,8 +17,6 @@
package org.apache.spark.examples
-import java.util.Random
-
import breeze.linalg.{Vector, DenseVector, squaredDistance}
import org.apache.spark.{SparkConf, SparkContext}
@@ -28,15 +26,12 @@ import org.apache.spark.SparkContext._
* K-means clustering.
*/
object SparkKMeans {
- val R = 1000 // Scaling factor
- val rand = new Random(42)
def parseVector(line: String): Vector[Double] = {
DenseVector(line.split(' ').map(_.toDouble))
}
def closestPoint(p: Vector[Double], centers: Array[Vector[Double]]): Int = {
- var index = 0
var bestIndex = 0
var closest = Double.PositiveInfinity
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 40b36c779afd6..4c7e006da0618 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -31,8 +31,12 @@ import org.apache.spark.{SparkConf, SparkContext}
*/
object SparkPageRank {
def main(args: Array[String]) {
+ if (args.length < 1) {
+ System.err.println("Usage: SparkPageRank ")
+ System.exit(1)
+ }
val sparkConf = new SparkConf().setAppName("PageRank")
- var iters = args(1).toInt
+ val iters = if (args.length > 0) args(1).toInt else 10
val ctx = new SparkContext(sparkConf)
val lines = ctx.textFile(args(0), 1)
val links = lines.map{ s =>
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index c1f581967777b..61a6aff543aed 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-streaming-flume_2.10
+
+ streaming-flume
+ jarSpark Project External Flumehttp://spark.apache.org/
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index ed35e34ad45ab..07ae88febf916 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.flume
import java.net.InetSocketAddress
import java.io.{ObjectInput, ObjectOutput, Externalizable}
import java.nio.ByteBuffer
+import java.util.concurrent.Executors
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -29,24 +30,32 @@ import org.apache.flume.source.avro.AvroFlumeEvent
import org.apache.flume.source.avro.Status
import org.apache.avro.ipc.specific.SpecificResponder
import org.apache.avro.ipc.NettyServer
-
+import org.apache.spark.Logging
import org.apache.spark.util.Utils
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream._
-import org.apache.spark.Logging
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.receiver.Receiver
+import org.jboss.netty.channel.ChannelPipelineFactory
+import org.jboss.netty.channel.Channels
+import org.jboss.netty.channel.ChannelPipeline
+import org.jboss.netty.channel.ChannelFactory
+import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory
+import org.jboss.netty.handler.codec.compression._
+import org.jboss.netty.handler.execution.ExecutionHandler
+
private[streaming]
class FlumeInputDStream[T: ClassTag](
@transient ssc_ : StreamingContext,
host: String,
port: Int,
- storageLevel: StorageLevel
+ storageLevel: StorageLevel,
+ enableDecompression: Boolean
) extends ReceiverInputDStream[SparkFlumeEvent](ssc_) {
override def getReceiver(): Receiver[SparkFlumeEvent] = {
- new FlumeReceiver(host, port, storageLevel)
+ new FlumeReceiver(host, port, storageLevel, enableDecompression)
}
}
@@ -134,22 +143,71 @@ private[streaming]
class FlumeReceiver(
host: String,
port: Int,
- storageLevel: StorageLevel
+ storageLevel: StorageLevel,
+ enableDecompression: Boolean
) extends Receiver[SparkFlumeEvent](storageLevel) with Logging {
lazy val responder = new SpecificResponder(
classOf[AvroSourceProtocol], new FlumeEventServer(this))
- lazy val server = new NettyServer(responder, new InetSocketAddress(host, port))
+ var server: NettyServer = null
+
+ private def initServer() = {
+ if (enableDecompression) {
+ val channelFactory = new NioServerSocketChannelFactory
+ (Executors.newCachedThreadPool(), Executors.newCachedThreadPool());
+ val channelPipelieFactory = new CompressionChannelPipelineFactory()
+
+ new NettyServer(
+ responder,
+ new InetSocketAddress(host, port),
+ channelFactory,
+ channelPipelieFactory,
+ null)
+ } else {
+ new NettyServer(responder, new InetSocketAddress(host, port))
+ }
+ }
def onStart() {
- server.start()
+ synchronized {
+ if (server == null) {
+ server = initServer()
+ server.start()
+ } else {
+ logWarning("Flume receiver being asked to start more then once with out close")
+ }
+ }
logInfo("Flume receiver started")
}
def onStop() {
- server.close()
+ synchronized {
+ if (server != null) {
+ server.close()
+ server = null
+ }
+ }
logInfo("Flume receiver stopped")
}
override def preferredLocation = Some(host)
+
+ /** A Netty Pipeline factory that will decompress incoming data from
+ * the Netty client and compress data going back to the client.
+ *
+ * The compression on the return path is required because Flume requires
+ * a successful response to indicate it can remove the event/batch
+ * from the configured channel.
+ */
+ private[streaming]
+ class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
+
+ def getPipeline() = {
+ val pipeline = Channels.pipeline()
+ val encoder = new ZlibEncoder(6)
+ pipeline.addFirst("deflater", encoder)
+ pipeline.addFirst("inflater", new ZlibDecoder())
+ pipeline
+ }
+}
}
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
index 499f3560ef768..716db9fa76031 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
@@ -36,7 +36,27 @@ object FlumeUtils {
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): ReceiverInputDStream[SparkFlumeEvent] = {
- val inputStream = new FlumeInputDStream[SparkFlumeEvent](ssc, hostname, port, storageLevel)
+ createStream(ssc, hostname, port, storageLevel, false)
+ }
+
+ /**
+ * Create an input stream from a Flume source.
+ * @param ssc StreamingContext object
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ * @param storageLevel Storage level to use for storing the received objects
+ * @param enableDecompression whether the Netty server should decompress the input stream
+ */
+ def createStream (
+ ssc: StreamingContext,
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel,
+ enableDecompression: Boolean
+ ): ReceiverInputDStream[SparkFlumeEvent] = {
+ val inputStream = new FlumeInputDStream[SparkFlumeEvent](
+ ssc, hostname, port, storageLevel, enableDecompression)
+
inputStream
}
@@ -66,6 +86,23 @@ object FlumeUtils {
port: Int,
storageLevel: StorageLevel
): JavaReceiverInputDStream[SparkFlumeEvent] = {
- createStream(jssc.ssc, hostname, port, storageLevel)
+ createStream(jssc.ssc, hostname, port, storageLevel, false)
+ }
+
+ /**
+ * Creates an input stream from a Flume source.
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ * @param storageLevel Storage level to use for storing the received objects
+ * @param enableDecompression whether the Netty server should decompress the input stream
+ */
+ def createStream(
+ jssc: JavaStreamingContext,
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel,
+ enableDecompression: Boolean
+ ): JavaReceiverInputDStream[SparkFlumeEvent] = {
+ createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression)
}
}
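A sketch of how the new Scala overload above could be used from an application (host, port, and batch interval are placeholders):

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.flume.FlumeUtils

val ssc = new StreamingContext(new SparkConf().setAppName("FlumeDecompression"), Seconds(2))
// enableDecompression = true makes the embedded Netty server zlib-decompress incoming events.
val events = FlumeUtils.createStream(
  ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2, enableDecompression = true)
events.count().print()
ssc.start()
{% endhighlight %}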
diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
index e0ad4f1015205..3b5e0c7746b2c 100644
--- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
+++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
@@ -30,5 +30,7 @@ public void testFlumeStream() {
JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345);
JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345,
StorageLevel.MEMORY_AND_DISK_SER_2());
+ JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345,
+ StorageLevel.MEMORY_AND_DISK_SER_2(), false);
}
}
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index dd287d0ef90a0..73dffef953309 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -33,15 +33,26 @@ import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuite
import org.apache.spark.streaming.util.ManualClock
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream
-class FlumeStreamSuite extends TestSuiteBase {
+import org.jboss.netty.channel.ChannelPipeline
+import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
+import org.jboss.netty.channel.socket.SocketChannel
+import org.jboss.netty.handler.codec.compression._
- val testPort = 9999
+class FlumeStreamSuite extends TestSuiteBase {
test("flume input stream") {
+ runFlumeStreamTest(false, 9998)
+ }
+
+ test("flume input compressed stream") {
+ runFlumeStreamTest(true, 9997)
+ }
+
+ def runFlumeStreamTest(enableDecompression: Boolean, testPort: Int) {
// Set up the streaming context and input streams
val ssc = new StreamingContext(conf, batchDuration)
val flumeStream: JavaReceiverInputDStream[SparkFlumeEvent] =
- FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK)
+ FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, enableDecompression)
val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
with SynchronizedBuffer[Seq[SparkFlumeEvent]]
val outputStream = new TestOutputStream(flumeStream.receiverInputDStream, outputBuffer)
@@ -52,8 +63,17 @@ class FlumeStreamSuite extends TestSuiteBase {
val input = Seq(1, 2, 3, 4, 5)
Thread.sleep(1000)
val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort))
- val client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol], transceiver)
+ var client: AvroSourceProtocol = null;
+
+ if (enableDecompression) {
+ client = SpecificRequestor.getClient(
+ classOf[AvroSourceProtocol],
+ new NettyTransceiver(new InetSocketAddress("localhost", testPort),
+ new CompressionChannelFactory(6)));
+ } else {
+ client = SpecificRequestor.getClient(
+ classOf[AvroSourceProtocol], transceiver)
+ }
for (i <- 0 until input.size) {
val event = new AvroFlumeEvent
@@ -64,6 +84,8 @@ class FlumeStreamSuite extends TestSuiteBase {
clock.addToTime(batchDuration.milliseconds)
}
+ Thread.sleep(1000)
+
val startTime = System.currentTimeMillis()
while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size)
@@ -85,4 +107,13 @@ class FlumeStreamSuite extends TestSuiteBase {
assert(outputBuffer(i).head.event.getHeaders.get("test") === "header")
}
}
+
+ class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
+ override def newChannel(pipeline:ChannelPipeline) : SocketChannel = {
+ var encoder : ZlibEncoder = new ZlibEncoder(compressionLevel);
+ pipeline.addFirst("deflater", encoder);
+ pipeline.addFirst("inflater", new ZlibDecoder());
+ super.newChannel(pipeline);
+ }
+ }
}
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index d014a7aad0fca..4762c50685a93 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-streaming-kafka_2.10
+
+ streaming-kafka
+ jarSpark Project External Kafkahttp://spark.apache.org/
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 4980208cba3b0..32c530e600ce0 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-streaming-mqtt_2.10
+
+ streaming-mqtt
+ jarSpark Project External MQTThttp://spark.apache.org/
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 7073bd4404d9c..637adb0f00da0 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-streaming-twitter_2.10
+
+ streaming-twitter
+ jarSpark Project External Twitterhttp://spark.apache.org/
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index cf306e0dca8bd..e4d758a04a4cd 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-streaming-zeromq_2.10
+
+ streaming-zeromq
+ jarSpark Project External ZeroMQhttp://spark.apache.org/
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 955ec1a8c3033..3eade411b38b7 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -28,7 +28,11 @@
java8-tests_2.10pomSpark Project Java8 Tests POM
-
+
+
+ java8-tests
+
+
org.apache.spark
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index 22ea330b4374d..a5b162a0482e4 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -29,7 +29,11 @@
spark-ganglia-lgpl_2.10jarSpark Ganglia Integration
-
+
+
+ ganglia-lgpl
+
+
org.apache.spark
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 7d5d83e7f3bb9..7e3bcf29dcfbc 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-graphx_2.10
+
+ graphx
+ jarSpark Project GraphXhttp://spark.apache.org/
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 4db45c9af8fae..3507f358bfb40 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -107,14 +107,16 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
/**
* Repartitions the edges in the graph according to `partitionStrategy`.
*
- * @param the partitioning strategy to use when partitioning the edges in the graph.
+ * @param partitionStrategy the partitioning strategy to use when partitioning the edges
+ * in the graph.
*/
def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED]
/**
* Repartitions the edges in the graph according to `partitionStrategy`.
*
- * @param the partitioning strategy to use when partitioning the edges in the graph.
+ * @param partitionStrategy the partitioning strategy to use when partitioning the edges
+ * in the graph.
* @param numPartitions the number of edge partitions in the new graph.
*/
def partitionBy(partitionStrategy: PartitionStrategy, numPartitions: Int): Graph[VD, ED]
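A brief sketch of the partitionBy overloads documented above, on a toy graph (sc is assumed to be an existing SparkContext; Graph.fromEdgeTuples and PartitionStrategy.EdgePartition2D are existing GraphX APIs):

{% highlight scala %}
import org.apache.spark.graphx.{Graph, PartitionStrategy}

// Build a tiny graph from raw edge tuples, then repartition its edges.
val edges = sc.parallelize(Seq((1L, 2L), (2L, 3L), (3L, 1L)))
val graph = Graph.fromEdgeTuples(edges, defaultValue = 1)
val repartitioned = graph.partitionBy(PartitionStrategy.EdgePartition2D, numPartitions = 4)
{% endhighlight %}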
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index f1b6df9a3025e..4825d12fc27b3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -182,8 +182,8 @@ class VertexRDD[@specialized VD: ClassTag](
/**
* Left joins this RDD with another VertexRDD with the same index. This function will fail if
* both VertexRDDs do not share the same index. The resulting vertex set contains an entry for
- * each
- * vertex in `this`. If `other` is missing any vertex in this VertexRDD, `f` is passed `None`.
+ * each vertex in `this`.
+ * If `other` is missing any vertex in this VertexRDD, `f` is passed `None`.
*
* @tparam VD2 the attribute type of the other VertexRDD
* @tparam VD3 the attribute type of the resulting VertexRDD
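A short sketch of the leftJoin contract described above (sc is assumed to be an existing SparkContext; the data is a toy example):

{% highlight scala %}
import org.apache.spark.graphx.VertexRDD

val verts = VertexRDD(sc.parallelize(Seq((1L, "a"), (2L, "b"))))
// A second VertexRDD derived from the first so it shares the same index, but is missing vertex 2.
val counts = verts.mapValues((vid, _) => 10).filter { case (vid, _) => vid == 1L }
// Every vertex in `verts` is kept; vertices missing from `counts` are passed None.
val joined = verts.leftJoin(counts) { (vid, label, countOpt) =>
  (label, countOpt.getOrElse(0))
}
// joined contains (1L, ("a", 10)) and (2L, ("b", 0))
{% endhighlight %}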
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 3827ac8d0fd6a..502b112d31c2e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -119,7 +119,7 @@ object RoutingTablePartition {
*/
private[graphx]
class RoutingTablePartition(
- private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) {
+ private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) extends Serializable {
/** The maximum number of edge partitions this `RoutingTablePartition` is built to join with. */
val numEdgePartitions: Int = routingTable.size
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
index 34939b24440aa..5ad6390a56c4f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
@@ -60,7 +60,8 @@ private[graphx] object VertexPartitionBase {
* `VertexPartitionBaseOpsConstructor` typeclass (for example,
* [[VertexPartition.VertexPartitionOpsConstructor]]).
*/
-private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] {
+private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag]
+ extends Serializable {
def index: VertexIdToIndexMap
def values: Array[VD]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
index a4f769b294010..b40aa1b417a0f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
@@ -35,7 +35,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
private[graphx] abstract class VertexPartitionBaseOps
[VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor]
(self: Self[VD])
- extends Logging {
+ extends Serializable with Logging {
def withIndex(index: VertexIdToIndexMap): Self[VD]
def withValues[VD2: ClassTag](values: Array[VD2]): Self[VD2]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 28fd112f2b124..9d00f76327e4c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -23,6 +23,7 @@ import scala.util.Random
import org.scalatest.FunSuite
import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
@@ -124,18 +125,21 @@ class EdgePartitionSuite extends FunSuite {
assert(ep.numActives == Some(2))
}
- test("Kryo serialization") {
+ test("serialization") {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
- val conf = new SparkConf()
+ val javaSer = new JavaSerializer(new SparkConf())
+ val kryoSer = new KryoSerializer(new SparkConf()
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
- val s = new KryoSerializer(conf).newInstance()
- val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
- assert(aSer.srcIds.toList === a.srcIds.toList)
- assert(aSer.dstIds.toList === a.dstIds.toList)
- assert(aSer.data.toList === a.data.toList)
- assert(aSer.index != null)
- assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+
+ for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
+ val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
+ assert(aSer.srcIds.toList === a.srcIds.toList)
+ assert(aSer.dstIds.toList === a.dstIds.toList)
+ assert(aSer.data.toList === a.data.toList)
+ assert(aSer.index != null)
+ assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+ }
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index 8bf1384d514c1..f9e771a900013 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -17,9 +17,14 @@
package org.apache.spark.graphx.impl
-import org.apache.spark.graphx._
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.KryoSerializer
+
+import org.apache.spark.graphx._
+
class VertexPartitionSuite extends FunSuite {
test("isDefined, filter") {
@@ -116,4 +121,17 @@ class VertexPartitionSuite extends FunSuite {
assert(vp3.index.getPos(2) === -1)
}
+ test("serialization") {
+ val verts = Set((0L, 1), (1L, 1), (2L, 1))
+ val vp = VertexPartition(verts.iterator)
+ val javaSer = new JavaSerializer(new SparkConf())
+ val kryoSer = new KryoSerializer(new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+
+ for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
+ val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
+ assert(vpSer.iterator.toSet === verts)
+ }
+ }
}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index b622f96dd7901..92b07e2357db1 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -27,6 +27,9 @@
org.apache.sparkspark-mllib_2.10
+
+ mllib
+ jarSpark Project ML Libraryhttp://spark.apache.org/
@@ -75,6 +78,19 @@
test
+
+
+ netlib-lgpl
+
+
+ com.github.fommil.netlib
+ all
+ 1.1.2
+ pom
+
+
+
+ target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
new file mode 100644
index 0000000000000..3515461b52493
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
+import com.github.fommil.netlib.ARPACK
+import org.netlib.util.{intW, doubleW}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Compute eigen-decomposition.
+ */
+@Experimental
+private[mllib] object EigenValueDecomposition {
+ /**
+ * Compute the leading k eigenvalues and eigenvectors on a symmetric square matrix using ARPACK.
+ * The caller needs to ensure that the input matrix is real symmetric. This function requires
+ * memory for `n*(4*k+4)` doubles.
+ *
+ * @param mul a function that multiplies the symmetric matrix with a DenseVector.
+ * @param n dimension of the square matrix (maximum Int.MaxValue).
+ * @param k number of leading eigenvalues required, 0 < k < n.
+ * @param tol tolerance of the eigs computation.
+ * @param maxIterations the maximum number of Arnoldi update iterations.
+ * @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors
+ * (columns of the matrix).
+ * @note The number of computed eigenvalues might be smaller than k when some Ritz values do not
+ * satisfy the convergence criterion specified by tol (see ARPACK Users Guide, Chapter 4.6
+ * for more details). The maximum number of Arnoldi update iterations is set to 300 in this
+ * function.
+ */
+ private[mllib] def symmetricEigs(
+ mul: BDV[Double] => BDV[Double],
+ n: Int,
+ k: Int,
+ tol: Double,
+ maxIterations: Int): (BDV[Double], BDM[Double]) = {
+ // TODO: remove this function and use eigs in breeze when switching breeze version
+ require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n")
+
+ val arpack = ARPACK.getInstance()
+
+ // tolerance used in stopping criterion
+ val tolW = new doubleW(tol)
+ // number of desired eigenvalues, 0 < nev < n
+ val nev = new intW(k)
+ // nev Lanczos vectors are generated in the first iteration
+ // ncv-nev Lanczos vectors are generated in each subsequent iteration
+ // ncv must be smaller than n
+ val ncv = math.min(2 * k, n)
+
+ // "I" for standard eigenvalue problem, "G" for generalized eigenvalue problem
+ val bmat = "I"
+ // "LM" : compute the NEV largest (in magnitude) eigenvalues
+ val which = "LM"
+
+ var iparam = new Array[Int](11)
+ // use exact shift in each iteration
+ iparam(0) = 1
+ // maximum number of Arnoldi update iterations, or the actual number of iterations on output
+ iparam(2) = maxIterations
+ // Mode 1: A*x = lambda*x, A symmetric
+ iparam(6) = 1
+
+ var ido = new intW(0)
+ var info = new intW(0)
+ var resid = new Array[Double](n)
+ var v = new Array[Double](n * ncv)
+ var workd = new Array[Double](n * 3)
+ var workl = new Array[Double](ncv * (ncv + 8))
+ var ipntr = new Array[Int](11)
+
+ // call ARPACK's reverse communication, first iteration with ido = 0
+ arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, workd,
+ workl, workl.length, info)
+
+ val w = BDV(workd)
+
+ // ido = 99 : done flag in reverse communication
+ while (ido.`val` != 99) {
+ if (ido.`val` != -1 && ido.`val` != 1) {
+ throw new IllegalStateException("ARPACK returns ido = " + ido.`val` +
+ " This flag is not compatible with Mode 1: A*x = lambda*x, A symmetric.")
+ }
+ // multiply working vector with the matrix
+ val inputOffset = ipntr(0) - 1
+ val outputOffset = ipntr(1) - 1
+ val x = w.slice(inputOffset, inputOffset + n)
+ val y = w.slice(outputOffset, outputOffset + n)
+ y := mul(x)
+ // call ARPACK's reverse communication
+ arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr,
+ workd, workl, workl.length, info)
+ }
+
+ if (info.`val` != 0) {
+ info.`val` match {
+ case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Maximum number of iterations taken. (Refer ARPACK user guide for details)")
+ case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " No shifts could be applied. Try to increase NCV. " +
+ "(Refer ARPACK user guide for details)")
+ case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ " Please refer ARPACK user guide for error message.")
+ }
+ }
+
+ val d = new Array[Double](nev.`val`)
+ val select = new Array[Boolean](ncv)
+ // copy the Ritz vectors
+ val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n)
+
+ // call ARPACK's post-processing for eigenvectors
+ arpack.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, ncv, v, n,
+ iparam, ipntr, workd, workl, workl.length, info)
+
+ // number of computed eigenvalues, might be smaller than k
+ val computed = iparam(4)
+
+ val eigenPairs = java.util.Arrays.copyOfRange(d, 0, computed).zipWithIndex.map { r =>
+ (r._1, java.util.Arrays.copyOfRange(z, r._2 * n, r._2 * n + n))
+ }
+
+ // sort the eigen-pairs in descending order
+ val sortedEigenPairs = eigenPairs.sortBy(- _._1)
+
+ // copy eigenvectors in descending order of eigenvalues
+ val sortedU = BDM.zeros[Double](n, computed)
+ sortedEigenPairs.zipWithIndex.foreach { r =>
+ val b = r._2 * n
+ var i = 0
+ while (i < n) {
+ sortedU.data(b + i) = r._1._2(i)
+ i += 1
+ }
+ }
+
+ (BDV[Double](sortedEigenPairs.map(_._1)), sortedU)
+ }
+}
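For orientation, a sketch of how the symmetricEigs helper added above could be exercised on a small symmetric matrix. Since the method is private[mllib], a real caller would have to live inside the org.apache.spark.mllib packages (or a test for them); the matrix, tolerance, and iteration count below are placeholders.

{% highlight scala %}
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}

// 3x3 real symmetric matrix, stored column-major as Breeze expects.
val a = new BDM[Double](3, 3, Array(2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 0.0, 1.0, 2.0))

// Matrix-vector product handed to ARPACK's reverse-communication loop.
val mul = (v: BDV[Double]) => a * v

// Leading eigenvalue/eigenvector pair (k = 1 < n = 3, as required).
val (eigenvalues, eigenvectors) =
  EigenValueDecomposition.symmetricEigs(mul, n = 3, k = 1, tol = 1e-10, maxIterations = 300)
{% endhighlight %}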
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index c818a0b9c3e43..77b3e8c714997 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -62,7 +62,7 @@ trait Vector extends Serializable {
* Gets the value of the ith element.
* @param i index
*/
- private[mllib] def apply(i: Int): Double = toBreeze(i)
+ def apply(i: Int): Double = toBreeze(i)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 695e03b736baf..f4c403bc7861c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -17,9 +17,10 @@
package org.apache.spark.mllib.linalg.distributed
-import java.util
+import java.util.Arrays
-import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
+import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV}
+import breeze.linalg.{svd => brzSvd, axpy => brzAxpy}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -27,138 +28,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
-import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
-
-/**
- * Column statistics aggregator implementing
- * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
- * together with add() and merge() function.
- * A numerically stable algorithm is implemented to compute sample mean and variance:
- *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
- * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
- * to have time complexity O(nnz) instead of O(n) for each column.
- */
-private class ColumnStatisticsAggregator(private val n: Int)
- extends MultivariateStatisticalSummary with Serializable {
-
- private val currMean: BDV[Double] = BDV.zeros[Double](n)
- private val currM2n: BDV[Double] = BDV.zeros[Double](n)
- private var totalCnt = 0.0
- private val nnz: BDV[Double] = BDV.zeros[Double](n)
- private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
- private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
-
- override def mean: Vector = {
- val realMean = BDV.zeros[Double](n)
- var i = 0
- while (i < n) {
- realMean(i) = currMean(i) * nnz(i) / totalCnt
- i += 1
- }
- Vectors.fromBreeze(realMean)
- }
-
- override def variance: Vector = {
- val realVariance = BDV.zeros[Double](n)
-
- val denominator = totalCnt - 1.0
-
- // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
- if (denominator > 0.0) {
- val deltaMean = currMean
- var i = 0
- while (i < currM2n.size) {
- realVariance(i) =
- currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
- realVariance(i) /= denominator
- i += 1
- }
- }
-
- Vectors.fromBreeze(realVariance)
- }
-
- override def count: Long = totalCnt.toLong
-
- override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
-
- override def max: Vector = {
- var i = 0
- while (i < n) {
- if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
- i += 1
- }
- Vectors.fromBreeze(currMax)
- }
-
- override def min: Vector = {
- var i = 0
- while (i < n) {
- if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
- i += 1
- }
- Vectors.fromBreeze(currMin)
- }
-
- /**
- * Aggregates a row.
- */
- def add(currData: BV[Double]): this.type = {
- currData.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
- if (currMax(i) < value) {
- currMax(i) = value
- }
- if (currMin(i) > value) {
- currMin(i) = value
- }
-
- val tmpPrevMean = currMean(i)
- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
-
- nnz(i) += 1.0
- }
-
- totalCnt += 1.0
- this
- }
-
- /**
- * Merges another aggregator.
- */
- def merge(other: ColumnStatisticsAggregator): this.type = {
- require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
-
- totalCnt += other.totalCnt
- val deltaMean = currMean - other.currMean
-
- var i = 0
- while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
- (nnz(i) + other.nnz(i))
- }
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
- }
- i += 1
- }
-
- nnz += other.nnz
- this
- }
-}
+import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
/**
* :: Experimental ::
@@ -200,6 +70,32 @@ class RowMatrix(
nRows
}
+ /**
+ * Multiplies the Gramian matrix `A^T A` by a dense vector on the right without computing `A^T A`.
+ *
+ * @param v a dense vector whose length must match the number of columns of this matrix
+ * @return a dense vector representing the product
+ */
+ private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = {
+ val n = numCols().toInt
+ val vbr = rows.context.broadcast(v)
+ rows.aggregate(BDV.zeros[Double](n))(
+ seqOp = (U, r) => {
+ val rBrz = r.toBreeze
+ val a = rBrz.dot(vbr.value)
+ rBrz match {
+ // use specialized axpy for better performance
+ case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], U)
+ case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], U)
+ case _ => throw new UnsupportedOperationException(
+ s"Do not support vector operation from type ${rBrz.getClass.getName}.")
+ }
+ U
+ },
+ combOp = (U1, U2) => U1 += U2
+ )
+ }
+
/**
* Computes the Gramian matrix `A^T A`.
*/
@@ -220,50 +116,135 @@ class RowMatrix(
}
/**
- * Computes the singular value decomposition of this matrix.
- * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'.
+ * Computes the singular value decomposition of this matrix. Denote this matrix by A (m x n). This
+ * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k
+ * singular values, U and V contain the corresponding singular vectors.
*
- * There is no restriction on m, but we require `n^2` doubles to fit in memory.
- * Further, n should be less than m.
-
- * The decomposition is computed by first computing A'A = V S^2 V',
- * computing svd locally on that (since n x n is small), from which we recover S and V.
- * Then we compute U via easy matrix multiplication as U = A * (V * S^-1).
- * Note that this approach requires `O(n^3)` time on the master node.
+ * At most the k largest non-zero singular values and their associated vectors are returned. If there are k
+ * such values, then the dimensions of the return will be:
+ * - U is a RowMatrix of size m x k that satisfies U' * U = eye(k),
+ * - s is a Vector of size k, holding the singular values in descending order,
+ * - V is a Matrix of size n x k that satisfies V' * V = eye(k).
+ *
+ * We assume n is smaller than m. The singular values and the right singular vectors are derived
+ * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix
+ * storing the left singular vectors, is computed via matrix multiplication as
+ * U = A * (V * S^-1^), if requested by the user. The actual method to use is determined
+ * automatically based on the cost:
+ * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian
+ * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver.
+ * This requires a single pass with O(n^2^) storage on each executor and on the driver, and
+ * O(n^2^ k) time on the driver.
+ * - Otherwise, we compute (A' * A) * v in a distributed fashion and send it to ARPACK's DSAUPD to
+ * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k)
+ * passes, O(n) storage on each executor, and O(n k) storage on the driver.
*
- * At most k largest non-zero singular values and associated vectors are returned.
- * If there are k such values, then the dimensions of the return will be:
+ * Several internal parameters are set to default values. The reciprocal condition number rCond
+ * is set to 1e-9. All singular values smaller than rCond * sigma(0) are treated as zeros, where
+ * sigma(0) is the largest singular value. The maximum number of Arnoldi update iterations for
+ * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's
+ * eigen-decomposition is set to 1e-10.
*
- * U is a RowMatrix of size m x k that satisfies U'U = eye(k),
- * s is a Vector of size k, holding the singular values in descending order,
- * and V is a Matrix of size n x k that satisfies V'V = eye(k).
+ * @note The conditions that decide which method to use internally and the default parameters are
+ * subject to change.
*
- * @param k number of singular values to keep. We might return less than k if there are
- * numerically zero singular values. See rCond.
+ * @param k number of leading singular values to keep (0 < k <= n). It might return fewer than k if
+ * there are numerically zero singular values or not enough Ritz values have
+ * converged before the maximum number of Arnoldi update iterations is reached (in case
+ * the matrix A is ill-conditioned).
* @param computeU whether to compute U
* @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0)
* are treated as zero, where sigma(0) is the largest singular value.
- * @return SingularValueDecomposition(U, s, V)
+ * @return SingularValueDecomposition(U, s, V). U = null if computeU = false.
*/
def computeSVD(
k: Int,
computeU: Boolean = false,
rCond: Double = 1e-9): SingularValueDecomposition[RowMatrix, Matrix] = {
+ // maximum number of Arnoldi update iterations for invoking ARPACK
+ val maxIter = math.max(300, k * 3)
+ // numerical tolerance for invoking ARPACK
+ val tol = 1e-10
+ computeSVD(k, computeU, rCond, maxIter, tol, "auto")
+ }
+
+ /**
+ * The actual SVD implementation, visible for testing.
+ *
+ * @param k number of leading singular values to keep (0 < k <= n)
+ * @param computeU whether to compute U
+ * @param rCond the reciprocal condition number
+ * @param maxIter max number of iterations (if ARPACK is used)
+ * @param tol termination tolerance (if ARPACK is used)
+ * @param mode computation mode (auto: determine automatically which mode to use,
+ * local-svd: compute the Gramian matrix and its full SVD locally,
+ * local-eigs: compute the Gramian matrix and its top eigenvalues locally,
+ * dist-eigs: compute the top eigenvalues of the Gramian matrix in a distributed way)
+ * @return SingularValueDecomposition(U, s, V). U = null if computeU = false.
+ */
+ private[mllib] def computeSVD(
+ k: Int,
+ computeU: Boolean,
+ rCond: Double,
+ maxIter: Int,
+ tol: Double,
+ mode: String): SingularValueDecomposition[RowMatrix, Matrix] = {
val n = numCols().toInt
- require(k > 0 && k <= n, s"Request up to n singular values k=$k n=$n.")
+ require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.")
- val G = computeGramianMatrix()
+ object SVDMode extends Enumeration {
+ val LocalARPACK, LocalLAPACK, DistARPACK = Value
+ }
+
+ val computeMode = mode match {
+ case "auto" =>
+ // TODO: The conditions below are not fully tested.
+ if (n < 100 || k > n / 2) {
+ // If n is small or k is large compared with n, it is better to compute the Gramian matrix
+ // first and then compute its eigenvalues locally, instead of making multiple passes.
+ if (k < n / 3) {
+ SVDMode.LocalARPACK
+ } else {
+ SVDMode.LocalLAPACK
+ }
+ } else {
+ // If k is small compared with n, we use ARPACK with distributed multiplication.
+ SVDMode.DistARPACK
+ }
+ case "local-svd" => SVDMode.LocalLAPACK
+ case "local-eigs" => SVDMode.LocalARPACK
+ case "dist-eigs" => SVDMode.DistARPACK
+ case _ => throw new IllegalArgumentException(s"Unsupported mode: $mode.")
+ }
+
+ // Compute the eigen-decomposition of A' * A.
+ val (sigmaSquares: BDV[Double], u: BDM[Double]) = computeMode match {
+ case SVDMode.LocalARPACK =>
+ require(k < n, s"k must be smaller than n in local-eigs mode but got k=$k and n=$n.")
+ val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]]
+ EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter)
+ case SVDMode.LocalLAPACK =>
+ val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]]
+ val (uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G)
+ (sigmaSquaresFull, uFull)
+ case SVDMode.DistARPACK =>
+ require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.")
+ EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter)
+ }
- // TODO: Use sparse SVD instead.
- val (u: BDM[Double], sigmaSquares: BDV[Double], v: BDM[Double]) =
- brzSvd(G.toBreeze.asInstanceOf[BDM[Double]])
val sigmas: BDV[Double] = brzSqrt(sigmaSquares)
- // Determine effective rank.
+ // Determine the effective rank.
val sigma0 = sigmas(0)
val threshold = rCond * sigma0
var i = 0
- while (i < k && sigmas(i) >= threshold) {
+ // sigmas might have a length smaller than k, if some Ritz values do not satisfy the convergence
+ // criterion specified by tol within the maximum number of iterations.
+ // Thus use i < min(k, sigmas.length) instead of i < k.
+ if (sigmas.length < k) {
+ logWarning(s"Requested $k singular values but only found ${sigmas.length} converged.")
+ }
+ while (i < math.min(k, sigmas.length) && sigmas(i) >= threshold) {
i += 1
}
val sk = i
@@ -272,12 +253,12 @@ class RowMatrix(
logWarning(s"Requested $k singular values but only found $sk nonzeros.")
}
- val s = Vectors.dense(util.Arrays.copyOfRange(sigmas.data, 0, sk))
- val V = Matrices.dense(n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk))
+ val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk))
+ val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
if (computeU) {
// N = Vk * Sk^{-1}
- val N = new BDM[Double](n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk))
+ val N = new BDM[Double](n, sk, Arrays.copyOfRange(u.data, 0, n * sk))
var i = 0
var j = 0
while (j < sk) {
@@ -364,7 +345,7 @@ class RowMatrix(
if (k == n) {
Matrices.dense(n, k, u.data)
} else {
- Matrices.dense(n, k, util.Arrays.copyOfRange(u.data, 0, n * k))
+ Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k))
}
}
@@ -372,8 +353,7 @@ class RowMatrix(
* Computes column-wise summary statistics.
*/
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
- val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
- val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
+ val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
@@ -390,15 +370,24 @@ class RowMatrix(
*/
def multiply(B: Matrix): RowMatrix = {
val n = numCols().toInt
+ val k = B.numCols
require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}")
require(B.isInstanceOf[DenseMatrix],
s"Only support dense matrix at this time but found ${B.getClass.getName}.")
- val Bb = rows.context.broadcast(B)
+ val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
val AB = rows.mapPartitions({ iter =>
- val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]]
- iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze))
+ val Bi = Bb.value
+ iter.map(row => {
+ val v = BDV.zeros[Double](k)
+ var i = 0
+ while (i < k) {
+ v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n))
+ i += 1
+ }
+ Vectors.fromBreeze(v)
+ })
}, preservesPartitioning = true)
new RowMatrix(AB, nRows, B.numCols)
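A short, hedged example of the public entry point described in the computeSVD scaladoc above; it assumes an existing SparkContext named sc and uses the default rCond, maxIter, and tol, letting the "auto" mode choose between the local and distributed eigen-decompositions:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 10.0),
  Vectors.dense(10.0, 11.0, 13.0)), 2)
val mat = new RowMatrix(rows, 4, 3)

// Keep the top 2 singular values; U is only materialized because computeU = true.
val svd = mat.computeSVD(2, computeU = true)
val U = svd.U  // RowMatrix of size m x k
val s = svd.s  // Vector of singular values in descending order
val V = svd.V  // Matrix of size n x k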
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
new file mode 100644
index 0000000000000..5105b5c37aaaa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * :: DeveloperApi ::
+ * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
+ * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
+ * format in an online fashion.
+ *
+ * Two MultivariateOnlineSummarizers can be merged to produce a statistical summary of
+ * the combined dataset.
+ *
+ * A numerically stable algorithm is implemented to compute the sample mean and variance:
+ * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
+ * Zero elements (including explicit zero values) are skipped when calling add(),
+ * so that the time complexity is O(nnz) instead of O(n) per sample.
+ */
+@DeveloperApi
+class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
+
+ private var n = 0
+ private var currMean: BDV[Double] = _
+ private var currM2n: BDV[Double] = _
+ private var totalCnt: Long = 0
+ private var nnz: BDV[Double] = _
+ private var currMax: BDV[Double] = _
+ private var currMin: BDV[Double] = _
+
+ /**
+ * Add a new sample to this summarizer, and update the statistical summary.
+ *
+ * @param sample The sample in dense/sparse vector format to be added into this summarizer.
+ * @return This MultivariateOnlineSummarizer object.
+ */
+ def add(sample: Vector): this.type = {
+ if (n == 0) {
+ require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.")
+ n = sample.toBreeze.length
+
+ currMean = BDV.zeros[Double](n)
+ currM2n = BDV.zeros[Double](n)
+ nnz = BDV.zeros[Double](n)
+ currMax = BDV.fill(n)(Double.MinValue)
+ currMin = BDV.fill(n)(Double.MaxValue)
+ }
+
+ require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $n but got ${sample.toBreeze.length}.")
+
+ sample.toBreeze.activeIterator.foreach {
+ case (_, 0.0) => // Skip explicit zero elements.
+ case (i, value) =>
+ if (currMax(i) < value) {
+ currMax(i) = value
+ }
+ if (currMin(i) > value) {
+ currMin(i) = value
+ }
+
+ val tmpPrevMean = currMean(i)
+ currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
+ currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
+
+ nnz(i) += 1.0
+ }
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another MultivariateOnlineSummarizer, and update the statistical summary.
+ * (Note that this is an in-place merge; as a result, `this` object will be modified.)
+ *
+ * @param other The other MultivariateOnlineSummarizer to be merged.
+ * @return This MultivariateOnlineSummarizer object.
+ */
+ def merge(other: MultivariateOnlineSummarizer): this.type = {
+ if (this.totalCnt != 0 && other.totalCnt != 0) {
+ require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
+ s"Expecting $n but got ${other.n}.")
+ totalCnt += other.totalCnt
+ val deltaMean: BDV[Double] = currMean - other.currMean
+ var i = 0
+ while (i < n) {
+ // merge mean together
+ if (other.currMean(i) != 0.0) {
+ currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
+ (nnz(i) + other.nnz(i))
+ }
+ // merge m2n together
+ if (nnz(i) + other.nnz(i) != 0.0) {
+ currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
+ (nnz(i) + other.nnz(i))
+ }
+ if (currMax(i) < other.currMax(i)) {
+ currMax(i) = other.currMax(i)
+ }
+ if (currMin(i) > other.currMin(i)) {
+ currMin(i) = other.currMin(i)
+ }
+ i += 1
+ }
+ nnz += other.nnz
+ } else if (totalCnt == 0 && other.totalCnt != 0) {
+ this.n = other.n
+ this.currMean = other.currMean.copy
+ this.currM2n = other.currM2n.copy
+ this.totalCnt = other.totalCnt
+ this.nnz = other.nnz.copy
+ this.currMax = other.currMax.copy
+ this.currMin = other.currMin.copy
+ }
+ this
+ }
+
+ override def mean: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ val realMean = BDV.zeros[Double](n)
+ var i = 0
+ while (i < n) {
+ realMean(i) = currMean(i) * (nnz(i) / totalCnt)
+ i += 1
+ }
+ Vectors.fromBreeze(realMean)
+ }
+
+ override def variance: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ val realVariance = BDV.zeros[Double](n)
+
+ val denominator = totalCnt - 1.0
+
+ // The sample variance is computed; if the denominator is not positive (i.e. fewer than two
+ // samples), the variance is just 0.
+ if (denominator > 0.0) {
+ val deltaMean = currMean
+ var i = 0
+ while (i < currM2n.size) {
+ realVariance(i) =
+ currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
+ realVariance(i) /= denominator
+ i += 1
+ }
+ }
+
+ Vectors.fromBreeze(realVariance)
+ }
+
+ override def count: Long = totalCnt
+
+ override def numNonzeros: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ Vectors.fromBreeze(nnz)
+ }
+
+ override def max: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMax)
+ }
+
+ override def min: Vector = {
+ require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+
+ var i = 0
+ while (i < n) {
+ if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
+ i += 1
+ }
+ Vectors.fromBreeze(currMin)
+ }
+}
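A brief sketch (not part of the patch) of how the summarizer is meant to be used: per-partition summaries are built incrementally with add() and then combined with merge(), which is the same pattern RowMatrix.computeColumnSummaryStatistics() now relies on.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val partA = (new MultivariateOnlineSummarizer)
  .add(Vectors.dense(-1.0, 0.0, 6.0))
  .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))

val partB = (new MultivariateOnlineSummarizer)
  .add(Vectors.dense(2.0, 1.0, -4.0))

val summary = partA.merge(partB)  // in-place merge into partA

println(summary.count)        // 3
println(summary.mean)         // column-wise means
println(summary.numNonzeros)  // column-wise nonzero counts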
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3b13e52a7b445..74d5d7ba10960 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
- binData(shift + (2 *(numBins - 2 - splitIndex))) +
+ binData(shift + (2 *(numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
- binData(shift + (3 * (numBins - 2 - splitIndex))) +
+ binData(shift + (3 * (numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
- binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+ binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
splitIndex += 1
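The index change above makes the right-node aggregate a proper suffix sum over bins: for split i, the aggregate is the bin just above the split plus the aggregate already accumulated for the next higher split. A simplified standalone sketch of that recurrence, with one statistic per bin and hypothetical names (this is not the DecisionTree code itself):

// binCounts(b) holds the aggregate for bin b; split i separates bins 0..i from bins i+1..numBins-1.
def rightNodeAggregates(binCounts: Array[Double]): Array[Double] = {
  val numBins = binCounts.length
  val right = new Array[Double](numBins - 1)
  right(numBins - 2) = binCounts(numBins - 1) // the highest split sees only the last bin
  var splitIndex = numBins - 3
  while (splitIndex >= 0) {
    // the bin just above this split, plus everything already to the right of the higher split
    right(splitIndex) = binCounts(splitIndex + 1) + right(splitIndex + 1)
    splitIndex -= 1
  }
  right
}

// rightNodeAggregates(Array(5.0, 2.0, 1.0, 4.0)) returns Array(7.0, 5.0, 4.0)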
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index c9f9acf4c1335..a961f89456a18 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -96,37 +96,44 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
test("svd of a full-rank matrix") {
for (mat <- Seq(denseMat, sparseMat)) {
- val localMat = mat.toBreeze()
- val (localU, localSigma, localVt) = brzSvd(localMat)
- val localV: BDM[Double] = localVt.t.toDenseMatrix
- for (k <- 1 to n) {
- val svd = mat.computeSVD(k, computeU = true)
- val U = svd.U
- val s = svd.s
- val V = svd.V
- assert(U.numRows() === m)
- assert(U.numCols() === k)
- assert(s.size === k)
- assert(V.numRows === n)
- assert(V.numCols === k)
- assertColumnEqualUpToSign(U.toBreeze(), localU, k)
- assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k)
- assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k)))
+ for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) {
+ val localMat = mat.toBreeze()
+ val (localU, localSigma, localVt) = brzSvd(localMat)
+ val localV: BDM[Double] = localVt.t.toDenseMatrix
+ for (k <- 1 to n) {
+ val skip = (mode == "local-eigs" || mode == "dist-eigs") && k == n
+ if (!skip) {
+ val svd = mat.computeSVD(k, computeU = true, 1e-9, 300, 1e-10, mode)
+ val U = svd.U
+ val s = svd.s
+ val V = svd.V
+ assert(U.numRows() === m)
+ assert(U.numCols() === k)
+ assert(s.size === k)
+ assert(V.numRows === n)
+ assert(V.numCols === k)
+ assertColumnEqualUpToSign(U.toBreeze(), localU, k)
+ assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k)
+ assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k)))
+ }
+ }
+ val svdWithoutU = mat.computeSVD(1, computeU = false, 1e-9, 300, 1e-10, mode)
+ assert(svdWithoutU.U === null)
}
- val svdWithoutU = mat.computeSVD(n)
- assert(svdWithoutU.U === null)
}
}
test("svd of a low-rank matrix") {
- val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0)), 2)
- val mat = new RowMatrix(rows, 4, 2)
- val svd = mat.computeSVD(2, computeU = true)
- assert(svd.s.size === 1, "should not return zero singular values")
- assert(svd.U.numRows() === 4)
- assert(svd.U.numCols() === 1)
- assert(svd.V.numRows === 2)
- assert(svd.V.numCols === 1)
+ val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0, 1.0)), 2)
+ val mat = new RowMatrix(rows, 4, 3)
+ for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) {
+ val svd = mat.computeSVD(2, computeU = true, 1e-6, 300, 1e-10, mode)
+ assert(svd.s.size === 1, s"should not return zero singular values but got ${svd.s}")
+ assert(svd.U.numRows() === 4)
+ assert(svd.U.numCols() === 1)
+ assert(svd.V.numRows === 3)
+ assert(svd.V.numCols === 1)
+ }
}
def closeToZero(G: BDM[Double]): Boolean = {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
new file mode 100644
index 0000000000000..4b7b019d820b4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.TestingUtils._
+
+class MultivariateOnlineSummarizerSuite extends FunSuite {
+
+ test("basic error handing") {
+ val summarizer = new MultivariateOnlineSummarizer
+
+ assert(summarizer.count === 0, "should be zero since nothing is added.")
+
+ withClue("Getting numNonzeros from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.numNonzeros
+ }
+ }
+
+ withClue("Getting variance from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.variance
+ }
+ }
+
+ withClue("Getting mean from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.mean
+ }
+ }
+
+ withClue("Getting max from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.max
+ }
+ }
+
+ withClue("Getting min from empty summarizer should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.min
+ }
+ }
+
+ summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0))))
+
+ withClue("Adding a new dense sample with different array size should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.add(Vectors.dense(3.0, 1.0))
+ }
+ }
+
+ withClue("Adding a new sparse sample with different array size should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0))))
+ }
+ }
+
+ val summarizer2 = (new MultivariateOnlineSummarizer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0))
+ withClue("Merging a new summarizer with different dimensions should throw exception.") {
+ intercept[IllegalArgumentException] {
+ summarizer.merge(summarizer2)
+ }
+ }
+ }
+
+ test("dense vector input") {
+ // For column 2, the maximum will be 0.0, and it's not explicitly added since we ignore all
+ // the zeros; it's a case we need to test. For column 3, the minimum will be 0.0 which we
+ // need to test as well.
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(-1.0, 0.0, 6.0))
+ .add(Vectors.dense(3.0, -3.0, 0.0))
+
+ assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+ assert(summarizer.count === 2)
+ }
+
+ test("sparse vector input") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
+ .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
+
+ assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch")
+
+ assert(summarizer.count === 2)
+ }
+
+ test("mixing dense and sparse vector input") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+ .add(Vectors.dense(0.0, -1.0, -3.0))
+ .add(Vectors.sparse(3, Seq((1, -5.1))))
+ .add(Vectors.dense(3.8, 0.0, 1.9))
+ .add(Vectors.dense(1.7, -0.6, 0.0))
+ .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+ assert(summarizer.mean.almostEquals(
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+ assert(summarizer.count === 6)
+ }
+
+ test("merging two summarizers") {
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
+ .add(Vectors.dense(0.0, -1.0, -3.0))
+
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((1, -5.1))))
+ .add(Vectors.dense(3.8, 0.0, 1.9))
+ .add(Vectors.dense(1.7, -0.6, 0.0))
+ .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
+
+ val summarizer = summarizer1.merge(summarizer2)
+
+ assert(summarizer.mean.almostEquals(
+ Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch")
+
+ assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch")
+
+ assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch")
+
+ assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch")
+
+ assert(summarizer.variance.almostEquals(
+ Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch")
+
+ assert(summarizer.count === 6)
+ }
+
+ test("merging summarizer with empty summarizer") {
+ // If one of the two summarizers is non-empty, merging should return the non-empty one.
+ // If both of them are empty, then the result is just an empty summarizer.
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(0.0, -1.0, -3.0)).merge(new MultivariateOnlineSummarizer)
+ assert(summarizer1.count === 1)
+
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .merge((new MultivariateOnlineSummarizer).add(Vectors.dense(0.0, -1.0, -3.0)))
+ assert(summarizer2.count === 1)
+
+ val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer)
+ assert(summarizer3.count === 0)
+
+ assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+ assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch")
+
+ assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+ assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch")
+
+ assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+ assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch")
+
+ assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch")
+
+ assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+
+ assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch")
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 35e92d71dc63f..bcb11876b8f4f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,8 +253,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.4)
- assert(stats.predict < 0.5)
+ assert(stats.predict > 0.5)
+ assert(stats.predict < 0.7)
assert(stats.impurity > 0.2)
}
@@ -280,8 +280,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.4)
- assert(stats.predict < 0.5)
+ assert(stats.predict > 0.5)
+ assert(stats.predict < 0.7)
assert(stats.impurity > 0.2)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
new file mode 100644
index 0000000000000..64b1ba7527183
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.apache.spark.mllib.linalg.Vector
+
+object TestingUtils {
+
+ implicit class DoubleWithAlmostEquals(val x: Double) {
+ // The relative error is always computed against the number with the larger magnitude,
+ // which avoids the problem of dividing by zero.
+ def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = {
+ if (x == y) {
+ true
+ } else if (math.abs(x) > math.abs(y)) {
+ math.abs(x - y) / math.abs(x) < epsilon
+ } else {
+ math.abs(x - y) / math.abs(y) < epsilon
+ }
+ }
+ }
+
+ implicit class VectorWithAlmostEquals(val x: Vector) {
+ def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = {
+ x.toArray.corresponds(y.toArray) {
+ _.almostEquals(_, epsilon)
+ }
+ }
+ }
+}
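A small usage sketch of the helpers above (illustrative only): with the implicits imported, both doubles and vectors gain an almostEquals that compares relative error against the default epsilon of 1e-10.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.TestingUtils._

assert(1.0.almostEquals(1.0 + 1e-12))     // tiny relative error passes
assert(!(1.0.almostEquals(1.001)))        // relative error of about 1e-3 fails
assert(0.0.almostEquals(0.0))             // equal values short-circuit, so no division by zero

// Vector comparison applies the same test element-wise.
assert(Vectors.dense(1.0, 2.0).almostEquals(Vectors.dense(1.0, 2.0 + 1e-12)))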
diff --git a/pom.xml b/pom.xml
index 910d91811e194..fce7fd96d1853 100644
--- a/pom.xml
+++ b/pom.xml
@@ -110,7 +110,7 @@
UTF-81.6
-
+ spark2.10.42.100.18.1
@@ -540,6 +540,10 @@
org.mortbay.jettyservlet-api-2.5
+
+ javax.servlet
+ servlet-api
+ junitjunit
@@ -623,6 +627,10 @@
hadoop-yarn-api${yarn.version}
+
+ javax.servlet
+ servlet-api
+ asmasm
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index bb2d73741c3bf..034ba6a7bf50f 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -15,13 +15,16 @@
* limitations under the License.
*/
+import sbt._
+import sbt.Keys.version
+
import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.core.MissingClassProblem
import com.typesafe.tools.mima.core.MissingTypesProblem
import com.typesafe.tools.mima.core.ProblemFilters._
import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact}
import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings
-import sbt._
+
object MimaBuild {
@@ -53,7 +56,7 @@ object MimaBuild {
excludePackage("org.apache.spark." + packageName)
}
- def ignoredABIProblems(base: File) = {
+ def ignoredABIProblems(base: File, currentSparkVersion: String) = {
// Excludes placed here will be used for all Spark versions
val defaultExcludes = Seq()
@@ -77,11 +80,16 @@ object MimaBuild {
}
defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++
- ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes
+ ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes(currentSparkVersion)
+ }
+
+ def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
+ val organization = "org.apache.spark"
+ val previousSparkVersion = "1.0.0"
+ val fullId = "spark-" + projectRef.project + "_2.10"
+ mimaDefaultSettings ++
+ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
+ binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value))
}
- def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq(
- previousArtifact := None,
- binaryIssueFilters ++= ignoredABIProblems(sparkHome)
- )
}
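A hedged illustration (not part of the patch) of how a subproject might wire up the new mimaSettings(sparkHome, projectRef) signature from an sbt build definition; coreRef is a hypothetical ProjectRef for the core module:

import sbt._

lazy val coreRef = ProjectRef(file("."), "core")
lazy val core = Project("core", file("core"))
  .settings(MimaBuild.mimaSettings(file("."), coreRef): _*)

// previousArtifact resolves to org.apache.spark % spark-core_2.10 % 1.0.0, and the
// binary-compatibility filters are generated for the version of the current build.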
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1621833e124f5..d67c6571a0623 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -31,8 +31,8 @@ import com.typesafe.tools.mima.core._
* MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap")
*/
object MimaExcludes {
- val excludes =
- SparkBuild.SPARK_VERSION match {
+ def excludes(version: String) =
+ version match {
case v if v.startsWith("1.1") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
@@ -64,6 +64,9 @@ object MimaExcludes {
"org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
+ "createZero$1")
) ++
+ Seq(
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this")
+ ) ++
Seq( // Ignore some private methods in ALS.
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
@@ -72,6 +75,7 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7")
) ++
+ MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++
MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
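Since excludes is now keyed on the version string, future exclusions follow the same pattern: add a version-prefixed case and list the filters under it. A hedged sketch using the filters already shown above and a purely hypothetical class and method name:

def excludes(version: String) = version match {
  case v if v.startsWith("1.2") =>
    Seq(
      MimaBuild.excludeSparkPackage("deploy"),
      // "org.apache.spark.SomeClass.someMethod" is a placeholder, not a real member.
      ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SomeClass.someMethod")
    )
  case _ => Seq()
}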
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8d1659d04e4f2..44abbc152f99f 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -15,524 +15,164 @@
* limitations under the License.
*/
-import sbt._
-import sbt.Classpaths.publishTask
-import sbt.Keys._
-import sbtassembly.Plugin._
-import AssemblyKeys._
import scala.util.Properties
-import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings}
-import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact
-import sbtunidoc.Plugin._
-import UnidocKeys._
-
import scala.collection.JavaConversions._
-// For Sonatype publishing
-// import com.jsuereth.pgp.sbtplugin.PgpKeys._
-
-object SparkBuild extends Build {
- val SPARK_VERSION = "1.1.0-SNAPSHOT"
- val SPARK_VERSION_SHORT = SPARK_VERSION.replaceAll("-SNAPSHOT", "")
-
- // Hadoop version to build against. For example, "1.0.4" for Apache releases, or
- // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set
- // through the environment variables SPARK_HADOOP_VERSION and SPARK_YARN.
- val DEFAULT_HADOOP_VERSION = "1.0.4"
-
- // Whether the Hadoop version to build against is 2.2.x, or a variant of it. This can be set
- // through the SPARK_IS_NEW_HADOOP environment variable.
- val DEFAULT_IS_NEW_HADOOP = false
-
- val DEFAULT_YARN = false
-
- val DEFAULT_HIVE = false
-
- // HBase version; set as appropriate.
- val HBASE_VERSION = "0.94.6"
-
- // Target JVM version
- val SCALAC_JVM_VERSION = "jvm-1.6"
- val JAVAC_JVM_VERSION = "1.6"
-
- lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects: _*)
-
- lazy val core = Project("core", file("core"), settings = coreSettings)
-
- /** Following project only exists to pull previous artifacts of Spark for generating
- Mima ignores. For more information see: SPARK 2071 */
- lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings)
-
- def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef
-
- lazy val repl = Project("repl", file("repl"), settings = replSettings)
- .dependsOn(replDependencies.map(a => a: sbt.ClasspathDep[sbt.ProjectReference]): _*)
-
- lazy val tools = Project("tools", file("tools"), settings = toolsSettings) dependsOn(core) dependsOn(streaming)
+import sbt._
+import sbt.Keys._
+import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings}
+import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
+import net.virtualvoid.sbt.graph.Plugin.graphSettings
- lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn(core)
+object BuildCommons {
- lazy val graphx = Project("graphx", file("graphx"), settings = graphxSettings) dependsOn(core)
+ private val buildLocation = file(".").getAbsoluteFile.getParentFile
- lazy val catalyst = Project("catalyst", file("sql/catalyst"), settings = catalystSettings) dependsOn(core)
+ val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming,
+ streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) =
+ Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql",
+ "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
+ "streaming-zeromq").map(ProjectRef(buildLocation, _))
- lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core) dependsOn(catalyst % "compile->compile;test->test")
+ val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) =
+ Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl")
+ .map(ProjectRef(buildLocation, _))
- lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql)
+ val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples")
+ .map(ProjectRef(buildLocation, _))
- lazy val maybeHive: Seq[ClasspathDependency] = if (isHiveEnabled) Seq(hive) else Seq()
- lazy val maybeHiveRef: Seq[ProjectReference] = if (isHiveEnabled) Seq(hive) else Seq()
+ val tools = "tools"
- lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core)
+ val sparkHome = buildLocation
+}
- lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core)
+object SparkBuild extends PomBuild {
- lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
- .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*)
+ import BuildCommons._
+ import scala.collection.mutable.Map
- lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps")
- lazy val assembleDeps = assembleDepsTask := {
- println()
- println("**** NOTE ****")
- println("'sbt/sbt assemble-deps' is no longer supported.")
- println("Instead create a normal assembly and:")
- println(" export SPARK_PREPEND_CLASSES=1 (toggle on)")
- println(" unset SPARK_PREPEND_CLASSES (toggle off)")
- println()
- }
+ val projectsMap: Map[String, Seq[Setting[_]]] = Map.empty
- // A configuration to set an alternative publishLocalConfiguration
- lazy val MavenCompile = config("m2r") extend(Compile)
- lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- val sparkHome = System.getProperty("user.dir")
-
- // Allows build configuration to be set through environment variables
- lazy val hadoopVersion = Properties.envOrElse("SPARK_HADOOP_VERSION", DEFAULT_HADOOP_VERSION)
- lazy val isNewHadoop = Properties.envOrNone("SPARK_IS_NEW_HADOOP") match {
- case None => {
- val isNewHadoopVersion = "^2\\.[2-9]+".r.findFirstIn(hadoopVersion).isDefined
- (isNewHadoopVersion|| DEFAULT_IS_NEW_HADOOP)
+ // Provides compatibility for older versions of the Spark build
+ def backwardCompatibility = {
+ import scala.collection.mutable
+ var isAlphaYarn = false
+ var profiles: mutable.Seq[String] = mutable.Seq.empty
+ if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) {
+ println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.")
+ profiles ++= Seq("spark-ganglia-lgpl")
+ }
+ if (Properties.envOrNone("SPARK_HIVE").isDefined) {
+ println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.")
+ profiles ++= Seq("hive")
}
- case Some(v) => v.toBoolean
+ Properties.envOrNone("SPARK_HADOOP_VERSION") match {
+ case Some(v) =>
+ if (v.matches("0.23.*")) isAlphaYarn = true
+ println("NOTE: SPARK_HADOOP_VERSION is deprecated, please use -Dhadoop.version=" + v)
+ System.setProperty("hadoop.version", v)
+ case None =>
+ }
+ if (Properties.envOrNone("SPARK_YARN").isDefined) {
+ if (isAlphaYarn) {
+ println("NOTE: SPARK_YARN is deprecated, please use -Pyarn-alpha flag.")
+ profiles ++= Seq("yarn-alpha")
+ } else {
+ println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
+ profiles ++= Seq("yarn")
+ }
+ }
+ profiles
}
- lazy val isYarnEnabled = Properties.envOrNone("SPARK_YARN") match {
- case None => DEFAULT_YARN
- case Some(v) => v.toBoolean
+ override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
+ case None => backwardCompatibility
+ case Some(v) =>
+ if (backwardCompatibility.nonEmpty)
+ println("Note: We ignore environment variables, when use of profile is detected in " +
+ "conjunction with environment variable.")
+ v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
}
- lazy val hadoopClient = if (hadoopVersion.startsWith("0.20.") || hadoopVersion == "1.0.0") "hadoop-core" else "hadoop-client"
- val maybeAvro = if (hadoopVersion.startsWith("0.23.")) Seq("org.apache.avro" % "avro" % "1.7.4") else Seq()
- lazy val isHiveEnabled = Properties.envOrNone("SPARK_HIVE") match {
- case None => DEFAULT_HIVE
- case Some(v) => v.toBoolean
+ Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
+ case Some(v) =>
+ v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1)))
+ case _ =>
}
- // Include Ganglia integration if the user has enabled Ganglia
- // This is isolated from the normal build due to LGPL-licensed code in the library
- lazy val isGangliaEnabled = Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined
- lazy val gangliaProj = Project("spark-ganglia-lgpl", file("extras/spark-ganglia-lgpl"), settings = gangliaSettings).dependsOn(core)
- val maybeGanglia: Seq[ClasspathDependency] = if (isGangliaEnabled) Seq(gangliaProj) else Seq()
- val maybeGangliaRef: Seq[ProjectReference] = if (isGangliaEnabled) Seq(gangliaProj) else Seq()
-
- // Include the Java 8 project if the JVM version is 8+
- lazy val javaVersion = System.getProperty("java.specification.version")
- lazy val isJava8Enabled = javaVersion.toDouble >= "1.8".toDouble
- val maybeJava8Tests = if (isJava8Enabled) Seq[ProjectReference](java8Tests) else Seq[ProjectReference]()
- lazy val java8Tests = Project("java8-tests", file("extras/java8-tests"), settings = java8TestsSettings).
- dependsOn(core) dependsOn(streaming % "compile->compile;test->test")
+ override val userPropertiesMap = System.getProperties.toMap
- // Include the YARN project if the user has enabled YARN
- lazy val yarnAlpha = Project("yarn-alpha", file("yarn/alpha"), settings = yarnAlphaSettings) dependsOn(core)
- lazy val yarn = Project("yarn", file("yarn/stable"), settings = yarnSettings) dependsOn(core)
-
- lazy val maybeYarn: Seq[ClasspathDependency] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq()
- lazy val maybeYarnRef: Seq[ProjectReference] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq()
-
- lazy val externalTwitter = Project("external-twitter", file("external/twitter"), settings = twitterSettings)
- .dependsOn(streaming % "compile->compile;test->test")
-
- lazy val externalKafka = Project("external-kafka", file("external/kafka"), settings = kafkaSettings)
- .dependsOn(streaming % "compile->compile;test->test")
-
- lazy val externalFlume = Project("external-flume", file("external/flume"), settings = flumeSettings)
- .dependsOn(streaming % "compile->compile;test->test")
-
- lazy val externalZeromq = Project("external-zeromq", file("external/zeromq"), settings = zeromqSettings)
- .dependsOn(streaming % "compile->compile;test->test")
-
- lazy val externalMqtt = Project("external-mqtt", file("external/mqtt"), settings = mqttSettings)
- .dependsOn(streaming % "compile->compile;test->test")
-
- lazy val allExternal = Seq[ClasspathDependency](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt)
- lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt)
-
- lazy val examples = Project("examples", file("examples"), settings = examplesSettings)
- .dependsOn(core, mllib, graphx, bagel, streaming, hive) dependsOn(allExternal: _*)
-
- // Everything except assembly, hive, tools, java8Tests and examples belong to packageProjects
- lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeHiveRef ++ maybeGangliaRef
-
- lazy val allProjects = packageProjects ++ allExternalRefs ++
- Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests
-
- def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq(
- organization := "org.apache.spark",
- version := SPARK_VERSION,
- scalaVersion := "2.10.4",
- scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-feature",
- "-target:" + SCALAC_JVM_VERSION),
- javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION),
- unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
+ lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ Seq (
+ javaHome := Properties.envOrNone("JAVA_HOME").map(file),
+ incOptions := incOptions.value.withNameHashing(true),
retrieveManaged := true,
- javaHome := Properties.envOrNone("JAVA_HOME").map(file),
- // This is to add convenience of enabling sbt -Dsbt.offline=true for making the build offline.
- offline := "true".equalsIgnoreCase(sys.props("sbt.offline")),
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
- transitiveClassifiers in Scope.GlobalScope := Seq("sources"),
- testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))),
- incOptions := incOptions.value.withNameHashing(true),
- // Fork new JVMs for tests and set Java options for those
- fork := true,
- javaOptions in Test += "-Dspark.home=" + sparkHome,
- javaOptions in Test += "-Dspark.testing=1",
- javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
- javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq,
- javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g".split(" ").toSeq,
- javaOptions += "-Xmx3g",
- // Show full stack trace and duration in test cases.
- testOptions in Test += Tests.Argument("-oDF"),
- // Remove certain packages from Scaladoc
- scalacOptions in (Compile, doc) := Seq(
- "-groups",
- "-skip-packages", Seq(
- "akka",
- "org.apache.spark.api.python",
- "org.apache.spark.network",
- "org.apache.spark.deploy",
- "org.apache.spark.util.collection"
- ).mkString(":"),
- "-doc-title", "Spark " + SPARK_VERSION_SHORT + " ScalaDoc"
- ),
-
- // Only allow one test at a time, even across projects, since they run in the same JVM
- concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
-
- resolvers ++= Seq(
- // HTTPS is unavailable for Maven Central
- "Maven Repository" at "http://repo.maven.apache.org/maven2",
- "Apache Repository" at "https://repository.apache.org/content/repositories/releases",
- "JBoss Repository" at "https://repository.jboss.org/nexus/content/repositories/releases/",
- "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/",
- "Cloudera Repository" at "http://repository.cloudera.com/artifactory/cloudera-repos/",
- "Pivotal Repository" at "http://repo.spring.io/libs-release/",
- // For Sonatype publishing
- // "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
- // "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/",
- // also check the local Maven repository ~/.m2
- Resolver.mavenLocal
- ),
-
- publishMavenStyle := true,
-
- // useGpg in Global := true,
-
- pomExtra := (
-
- org.apache
- apache
- 14
-
- http://spark.apache.org/
-
-
- Apache 2.0 License
- http://www.apache.org/licenses/LICENSE-2.0.html
- repo
-
-
-
- scm:git:git@github.com:apache/spark.git
- scm:git:git@github.com:apache/spark.git
-
-
-
- matei
- Matei Zaharia
- matei.zaharia@gmail.com
- http://www.cs.berkeley.edu/~matei
- Apache Software Foundation
- http://spark.apache.org
-
-
-
- JIRA
- https://issues.apache.org/jira/browse/SPARK
-
- ),
-
- /*
- publishTo <<= version { (v: String) =>
- val nexus = "https://oss.sonatype.org/"
- if (v.trim.endsWith("SNAPSHOT"))
- Some("sonatype-snapshots" at nexus + "content/repositories/snapshots")
- else
- Some("sonatype-staging" at nexus + "service/local/staging/deploy/maven2")
- },
+ publishMavenStyle := true
+ )
- */
-
- libraryDependencies ++= Seq(
- "io.netty" % "netty-all" % "4.0.17.Final",
- "org.eclipse.jetty" % "jetty-server" % jettyVersion,
- "org.eclipse.jetty" % "jetty-util" % jettyVersion,
- "org.eclipse.jetty" % "jetty-plus" % jettyVersion,
- "org.eclipse.jetty" % "jetty-security" % jettyVersion,
- "org.scalatest" %% "scalatest" % "2.1.5" % "test",
- "org.scalacheck" %% "scalacheck" % "1.11.3" % "test",
- "com.novocode" % "junit-interface" % "0.10" % "test",
- "org.easymock" % "easymockclassextension" % "3.1" % "test",
- "org.mockito" % "mockito-all" % "1.9.0" % "test",
- "junit" % "junit" % "4.10" % "test",
- // Needed by cglib which is needed by easymock.
- "asm" % "asm" % "3.3.1" % "test"
- ),
+ /** The following project exists only to pull previous artifacts of Spark for generating
+ Mima ignores. For more information see: SPARK-2071 */
+ lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings)
- testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
- parallelExecution := true,
- /* Workaround for issue #206 (fixed after SBT 0.11.0) */
- watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
- const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) },
-
- otherResolvers := Seq(Resolver.file("dotM2", file(Path.userHome + "/.m2/repository"))),
- publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map {
- (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level)
- },
- publishMavenStyle in MavenCompile := true,
- publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
- publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
- ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings ++ ScalaStyleSettings ++ genjavadocSettings
-
- val akkaVersion = "2.2.3-shaded-protobuf"
- val chillVersion = "0.3.6"
- val codahaleMetricsVersion = "3.0.0"
- val jblasVersion = "1.2.3"
- val jets3tVersion = if ("^2\\.[3-9]+".r.findFirstIn(hadoopVersion).isDefined) "0.9.0" else "0.7.1"
- val jettyVersion = "8.1.14.v20131031"
- val hiveVersion = "0.12.0"
- val parquetVersion = "1.4.3"
- val slf4jVersion = "1.7.5"
-
- val excludeJBossNetty = ExclusionRule(organization = "org.jboss.netty")
- val excludeIONetty = ExclusionRule(organization = "io.netty")
- val excludeEclipseJetty = ExclusionRule(organization = "org.eclipse.jetty")
- val excludeAsm = ExclusionRule(organization = "org.ow2.asm")
- val excludeOldAsm = ExclusionRule(organization = "asm")
- val excludeCommonsLogging = ExclusionRule(organization = "commons-logging")
- val excludeSLF4J = ExclusionRule(organization = "org.slf4j")
- val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap")
- val excludeHadoop = ExclusionRule(organization = "org.apache.hadoop")
- val excludeCurator = ExclusionRule(organization = "org.apache.curator")
- val excludePowermock = ExclusionRule(organization = "org.powermock")
- val excludeFastutil = ExclusionRule(organization = "it.unimi.dsi")
- val excludeJruby = ExclusionRule(organization = "org.jruby")
- val excludeThrift = ExclusionRule(organization = "org.apache.thrift")
- val excludeServletApi = ExclusionRule(organization = "javax.servlet", artifact = "servlet-api")
- val excludeJUnit = ExclusionRule(organization = "junit")
-
- def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark",
- version: String = "1.0.0", crossVersion: String = "2.10"): Option[sbt.ModuleID] = {
- val fullId = if (crossVersion.isEmpty) id else id + "_" + crossVersion
- Some(organization % fullId % version) // the artifact to compare binary compatibility with
+ def versionArtifact(id: String): Option[sbt.ModuleID] = {
+ val fullId = id + "_2.10"
+ Some("org.apache.spark" % fullId % "1.0.0")
}
- def coreSettings = sharedSettings ++ Seq(
- name := "spark-core",
- libraryDependencies ++= Seq(
- "com.google.guava" % "guava" % "14.0.1",
- "org.apache.commons" % "commons-lang3" % "3.3.2",
- "org.apache.commons" % "commons-math3" % "3.3",
- "com.google.code.findbugs" % "jsr305" % "1.3.9",
- "log4j" % "log4j" % "1.2.17",
- "org.slf4j" % "slf4j-api" % slf4jVersion,
- "org.slf4j" % "slf4j-log4j12" % slf4jVersion,
- "org.slf4j" % "jul-to-slf4j" % slf4jVersion,
- "org.slf4j" % "jcl-over-slf4j" % slf4jVersion,
- "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407
- "com.ning" % "compress-lzf" % "1.0.0",
- "org.xerial.snappy" % "snappy-java" % "1.0.5",
- "org.spark-project.akka" %% "akka-remote" % akkaVersion,
- "org.spark-project.akka" %% "akka-slf4j" % akkaVersion,
- "org.spark-project.akka" %% "akka-testkit" % akkaVersion % "test",
- "org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap),
- "colt" % "colt" % "1.2.0",
- "org.apache.mesos" % "mesos" % "0.18.1" classifier("shaded-protobuf") exclude("com.google.protobuf", "protobuf-java"),
- "commons-net" % "commons-net" % "2.2",
- "net.java.dev.jets3t" % "jets3t" % jets3tVersion excludeAll(excludeCommonsLogging),
- "commons-codec" % "commons-codec" % "1.5", // Prevent jets3t from including the older version of commons-codec
- "org.apache.derby" % "derby" % "10.4.2.0" % "test",
- "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm, excludeServletApi),
- "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeJBossNetty),
- "com.codahale.metrics" % "metrics-core" % codahaleMetricsVersion,
- "com.codahale.metrics" % "metrics-jvm" % codahaleMetricsVersion,
- "com.codahale.metrics" % "metrics-json" % codahaleMetricsVersion,
- "com.codahale.metrics" % "metrics-graphite" % codahaleMetricsVersion,
- "com.twitter" %% "chill" % chillVersion excludeAll(excludeAsm),
- "com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm),
- "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock),
- "com.clearspring.analytics" % "stream" % "2.7.0" excludeAll(excludeFastutil), // Only HyperLogLogPlus is used, which does not depend on fastutil.
- "org.spark-project" % "pyrolite" % "2.0.1",
- "net.sf.py4j" % "py4j" % "0.8.1"
- ),
- libraryDependencies ++= maybeAvro,
- assembleDeps,
- previousArtifact := sparkPreviousArtifact("spark-core")
+ def oldDepsSettings() = Defaults.defaultSettings ++ Seq(
+ name := "old-deps",
+ scalaVersion := "2.10.4",
+ retrieveManaged := true,
+ retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
+ libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq",
+ "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter",
+ "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx",
+ "spark-core").map(versionArtifact(_).get intransitive())
)
- // Create a colon-separate package list adding "org.apache.spark" in front of all of them,
- // for easier specification of JavaDoc package groups
- def packageList(names: String*): String = {
- names.map(s => "org.apache.spark." + s).mkString(":")
+ def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
+ val existingSettings = projectsMap.getOrElse(projectRef.project, Seq[Setting[_]]())
+ projectsMap += (projectRef.project -> (existingSettings ++ settings))
}
- def rootSettings = sharedSettings ++ scalaJavaUnidocSettings ++ Seq(
- publish := {},
+ // Note: the ordering of these settings matters.
+ /* Enable shared settings on all projects */
+ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings))
- unidocProjectFilter in (ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha),
- unidocProjectFilter in (JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(repl, examples, bagel, graphx, catalyst, tools, yarn, yarnAlpha),
+ /* Enable tests settings for all projects except examples, assembly and tools */
+ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
- // Skip class names containing $ and some internal packages in Javadocs
- unidocAllSources in (JavaUnidoc, unidoc) := {
- (unidocAllSources in (JavaUnidoc, unidoc)).value
- .map(_.filterNot(_.getName.contains("$")))
- .map(_.filterNot(_.getCanonicalPath.contains("akka")))
- .map(_.filterNot(_.getCanonicalPath.contains("deploy")))
- .map(_.filterNot(_.getCanonicalPath.contains("network")))
- .map(_.filterNot(_.getCanonicalPath.contains("executor")))
- .map(_.filterNot(_.getCanonicalPath.contains("python")))
- .map(_.filterNot(_.getCanonicalPath.contains("collection")))
- },
+ /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */
+ // TODO: Add Sql to mima checks
+ allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)).
+ foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x))
- // Javadoc options: create a window title, and group key packages on index page
- javacOptions in doc := Seq(
- "-windowtitle", "Spark " + SPARK_VERSION_SHORT + " JavaDoc",
- "-public",
- "-group", "Core Java API", packageList("api.java", "api.java.function"),
- "-group", "Spark Streaming", packageList(
- "streaming.api.java", "streaming.flume", "streaming.kafka",
- "streaming.mqtt", "streaming.twitter", "streaming.zeromq"
- ),
- "-group", "MLlib", packageList(
- "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg",
- "mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation",
- "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration",
- "mllib.tree.impurity", "mllib.tree.model", "mllib.util"
- ),
- "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"),
- "-noqualifier", "java.lang"
- )
- )
+ /* Enable Assembly for all assembly projects */
+ assemblyProjects.foreach(enable(Assembly.settings))
- def replSettings = sharedSettings ++ Seq(
- name := "spark-repl",
- libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v),
- libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "jline" % v),
- libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v)
- )
+ /* Enable unidoc only for the root spark project */
+ enable(Unidoc.settings)(spark)
- def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples",
- jarName in assembly <<= version map {
- v => "spark-examples-" + v + "-hadoop" + hadoopVersion + ".jar" },
- libraryDependencies ++= Seq(
- "com.twitter" %% "algebird-core" % "0.1.11",
- "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeIONetty, excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeJruby),
- "org.apache.cassandra" % "cassandra-all" % "1.2.6"
- exclude("com.google.guava", "guava")
- exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
- exclude("com.ning","compress-lzf")
- exclude("io.netty", "netty")
- exclude("jline","jline")
- exclude("org.apache.cassandra.deps", "avro")
- excludeAll(excludeSLF4J, excludeIONetty),
- "com.github.scopt" %% "scopt" % "3.2.0"
- )
- ) ++ assemblySettings ++ extraAssemblySettings
-
- def toolsSettings = sharedSettings ++ Seq(
- name := "spark-tools",
- libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v),
- libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v )
- ) ++ assemblySettings ++ extraAssemblySettings
-
- def graphxSettings = sharedSettings ++ Seq(
- name := "spark-graphx",
- previousArtifact := sparkPreviousArtifact("spark-graphx"),
- libraryDependencies ++= Seq(
- "org.jblas" % "jblas" % jblasVersion
- )
- )
+ /* Hive console settings */
+ enable(Hive.settings)(hive)
- def bagelSettings = sharedSettings ++ Seq(
- name := "spark-bagel",
- previousArtifact := sparkPreviousArtifact("spark-bagel")
- )
+ // TODO: move this to its upstream project.
+ override def projectDefinitions(baseDirectory: File): Seq[Project] = {
+ super.projectDefinitions(baseDirectory).map { x =>
+ if (projectsMap.exists(_._1 == x.id)) x.settings(projectsMap(x.id): _*)
+ else x.settings(Seq[Setting[_]](): _*)
+ } ++ Seq[Project](oldDeps)
+ }
- def mllibSettings = sharedSettings ++ Seq(
- name := "spark-mllib",
- previousArtifact := sparkPreviousArtifact("spark-mllib"),
- libraryDependencies ++= Seq(
- "org.jblas" % "jblas" % jblasVersion,
- "org.scalanlp" %% "breeze" % "0.7" excludeAll(excludeJUnit)
- )
- )
+}
- def catalystSettings = sharedSettings ++ Seq(
- name := "catalyst",
- // The mechanics of rewriting expression ids to compare trees in some test cases makes
- // assumptions about the the expression ids being contiguous. Running tests in parallel breaks
- // this non-deterministically. TODO: FIX THIS.
- parallelExecution in Test := false,
- libraryDependencies ++= Seq(
- "com.typesafe" %% "scalalogging-slf4j" % "1.0.1"
- )
- )
+object Hive {
- def sqlCoreSettings = sharedSettings ++ Seq(
- name := "spark-sql",
- libraryDependencies ++= Seq(
- "com.twitter" % "parquet-column" % parquetVersion,
- "com.twitter" % "parquet-hadoop" % parquetVersion,
- "com.fasterxml.jackson.core" % "jackson-databind" % "2.3.0" // json4s-jackson 3.2.6 requires jackson-databind 2.3.0.
- ),
- initialCommands in console :=
- """
- |import org.apache.spark.sql.catalyst.analysis._
- |import org.apache.spark.sql.catalyst.dsl._
- |import org.apache.spark.sql.catalyst.errors._
- |import org.apache.spark.sql.catalyst.expressions._
- |import org.apache.spark.sql.catalyst.plans.logical._
- |import org.apache.spark.sql.catalyst.rules._
- |import org.apache.spark.sql.catalyst.types._
- |import org.apache.spark.sql.catalyst.util._
- |import org.apache.spark.sql.execution
- |import org.apache.spark.sql.test.TestSQLContext._
- |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin
- )
+ lazy val settings = Seq(
- // Since we don't include hive in the main assembly this project also acts as an alternative
- // assembly jar.
- def hiveSettings = sharedSettings ++ Seq(
- name := "spark-hive",
javaOptions += "-XX:MaxPermSize=1g",
- libraryDependencies ++= Seq(
- "org.spark-project.hive" % "hive-metastore" % hiveVersion,
- "org.spark-project.hive" % "hive-exec" % hiveVersion excludeAll(excludeCommonsLogging),
- "org.spark-project.hive" % "hive-serde" % hiveVersion
- ),
- // Multiple queries rely on the TestHive singleton. See comments there for more details.
+ // Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
// only for this subproject.
@@ -555,67 +195,16 @@ object SparkBuild extends Build {
|import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin
)
- def streamingSettings = sharedSettings ++ Seq(
- name := "spark-streaming",
- previousArtifact := sparkPreviousArtifact("spark-streaming")
- )
-
- def yarnCommonSettings = sharedSettings ++ Seq(
- unmanagedSourceDirectories in Compile <++= baseDirectory { base =>
- Seq(
- base / "../common/src/main/scala"
- )
- },
-
- unmanagedSourceDirectories in Test <++= baseDirectory { base =>
- Seq(
- base / "../common/src/test/scala"
- )
- }
-
- ) ++ extraYarnSettings
-
- def yarnAlphaSettings = yarnCommonSettings ++ Seq(
- name := "spark-yarn-alpha"
- )
-
- def yarnSettings = yarnCommonSettings ++ Seq(
- name := "spark-yarn"
- )
-
- def gangliaSettings = sharedSettings ++ Seq(
- name := "spark-ganglia-lgpl",
- libraryDependencies += "com.codahale.metrics" % "metrics-ganglia" % "3.0.0"
- )
-
- def java8TestsSettings = sharedSettings ++ Seq(
- name := "java8-tests",
- javacOptions := Seq("-target", "1.8", "-source", "1.8"),
- testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a")
- )
-
- // Conditionally include the YARN dependencies because some tools look at all sub-projects and will complain
- // if we refer to nonexistent dependencies (e.g. hadoop-yarn-api from a Hadoop version without YARN).
- def extraYarnSettings = if(isYarnEnabled) yarnEnabledSettings else Seq()
-
- def yarnEnabledSettings = Seq(
- libraryDependencies ++= Seq(
- // Exclude rule required for all ?
- "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm),
- "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging),
- "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging),
- "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging),
- "org.apache.hadoop" % "hadoop-yarn-server-web-proxy" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeServletApi)
- )
- )
+}
- def assemblyProjSettings = sharedSettings ++ Seq(
- name := "spark-assembly",
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
- ) ++ assemblySettings ++ extraAssemblySettings
+object Assembly {
+ import sbtassembly.Plugin._
+ import AssemblyKeys._
- def extraAssemblySettings() = Seq(
+ lazy val settings = assemblySettings ++ Seq(
test in assembly := {},
+ jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" +
+ Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" },
mergeStrategy in assembly := {
case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
@@ -627,57 +216,95 @@ object SparkBuild extends Build {
}
)
- def oldDepsSettings() = Defaults.defaultSettings ++ Seq(
- name := "old-deps",
- scalaVersion := "2.10.4",
- retrieveManaged := true,
- retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
- libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq",
- "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter",
- "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx",
- "spark-core").map(sparkPreviousArtifact(_).get intransitive())
- )
+}
- def twitterSettings() = sharedSettings ++ Seq(
- name := "spark-streaming-twitter",
- previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"),
- libraryDependencies ++= Seq(
- "org.twitter4j" % "twitter4j-stream" % "3.0.3"
- )
- )
+object Unidoc {
- def kafkaSettings() = sharedSettings ++ Seq(
- name := "spark-streaming-kafka",
- previousArtifact := sparkPreviousArtifact("spark-streaming-kafka"),
- libraryDependencies ++= Seq(
- "com.github.sgroschupf" % "zkclient" % "0.1",
- "org.apache.kafka" %% "kafka" % "0.8.0"
- exclude("com.sun.jdmk", "jmxtools")
- exclude("com.sun.jmx", "jmxri")
- exclude("net.sf.jopt-simple", "jopt-simple")
- excludeAll(excludeSLF4J)
- )
- )
+ import BuildCommons._
+ import sbtunidoc.Plugin._
+ import UnidocKeys._
+
+ // for easier specification of JavaDoc package groups
+ private def packageList(names: String*): String = {
+ names.map(s => "org.apache.spark." + s).mkString(":")
+ }
- def flumeSettings() = sharedSettings ++ Seq(
- name := "spark-streaming-flume",
- previousArtifact := sparkPreviousArtifact("spark-streaming-flume"),
- libraryDependencies ++= Seq(
- "org.apache.flume" % "flume-ng-sdk" % "1.4.0" % "compile" excludeAll(excludeIONetty, excludeThrift)
+ lazy val settings = scalaJavaUnidocSettings ++ Seq (
+ publish := {},
+
+ unidocProjectFilter in(ScalaUnidoc, unidoc) :=
+ inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha),
+ unidocProjectFilter in(JavaUnidoc, unidoc) :=
+ inAnyProject -- inProjects(repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha),
+
+ // Skip class names containing $ and some internal packages in Javadocs
+ unidocAllSources in (JavaUnidoc, unidoc) := {
+ (unidocAllSources in (JavaUnidoc, unidoc)).value
+ .map(_.filterNot(_.getName.contains("$")))
+ .map(_.filterNot(_.getCanonicalPath.contains("akka")))
+ .map(_.filterNot(_.getCanonicalPath.contains("deploy")))
+ .map(_.filterNot(_.getCanonicalPath.contains("network")))
+ .map(_.filterNot(_.getCanonicalPath.contains("executor")))
+ .map(_.filterNot(_.getCanonicalPath.contains("python")))
+ .map(_.filterNot(_.getCanonicalPath.contains("collection")))
+ },
+
+ // Javadoc options: create a window title, and group key packages on index page
+ javacOptions in doc := Seq(
+ "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc",
+ "-public",
+ "-group", "Core Java API", packageList("api.java", "api.java.function"),
+ "-group", "Spark Streaming", packageList(
+ "streaming.api.java", "streaming.flume", "streaming.kafka",
+ "streaming.mqtt", "streaming.twitter", "streaming.zeromq"
+ ),
+ "-group", "MLlib", packageList(
+ "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg",
+ "mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation",
+ "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration",
+ "mllib.tree.impurity", "mllib.tree.model", "mllib.util"
+ ),
+ "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"),
+ "-noqualifier", "java.lang"
)
)
+}
- def zeromqSettings() = sharedSettings ++ Seq(
- name := "spark-streaming-zeromq",
- previousArtifact := sparkPreviousArtifact("spark-streaming-zeromq"),
- libraryDependencies ++= Seq(
- "org.spark-project.akka" %% "akka-zeromq" % akkaVersion
+object TestSettings {
+ import BuildCommons._
+
+ lazy val settings = Seq (
+ // Fork new JVMs for tests and set Java options for those
+ fork := true,
+ javaOptions in Test += "-Dspark.home=" + sparkHome,
+ javaOptions in Test += "-Dspark.testing=1",
+ javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
+ javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
+ .map { case (k,v) => s"-D$k=$v" }.toSeq,
+ javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
+ .split(" ").toSeq,
+ javaOptions += "-Xmx3g",
+
+ // Show full stack trace and duration in test cases.
+ testOptions in Test += Tests.Argument("-oDF"),
+ testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
+ // Enable JUnit testing.
+ libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test",
+ // Only allow one test at a time, even across projects, since they run in the same JVM
+ parallelExecution in Test := false,
+ concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
+ // Remove certain packages from Scaladoc
+ scalacOptions in (Compile, doc) := Seq(
+ "-groups",
+ "-skip-packages", Seq(
+ "akka",
+ "org.apache.spark.api.python",
+ "org.apache.spark.network",
+ "org.apache.spark.deploy",
+ "org.apache.spark.util.collection"
+ ).mkString(":"),
+ "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc"
)
)
- def mqttSettings() = streamingSettings ++ Seq(
- name := "spark-streaming-mqtt",
- previousArtifact := sparkPreviousArtifact("spark-streaming-mqtt"),
- libraryDependencies ++= Seq("org.eclipse.paho" % "mqtt-client" % "0.4.0")
- )
}
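
The SparkBuild.scala rewrite above replaces the old per-project `fooSettings` methods with a map from project name to settings, filled in through `enable(...)` and applied once when `projectDefinitions` runs. As a rough illustration of that pattern only, here is a minimal, self-contained Scala sketch that uses hypothetical `Project`/`Setting` stand-ins rather than real sbt types:

import scala.collection.mutable

object SettingsMapSketch {
  // Hypothetical stand-ins for sbt's Project and Setting types.
  case class Setting(key: String, value: String)
  case class Project(id: String, settings: Seq[Setting] = Nil) {
    def withSettings(extra: Seq[Setting]): Project = copy(settings = settings ++ extra)
  }

  // Settings accumulated per project id, mirroring SparkBuild's projectsMap/enable.
  private val projectsMap = mutable.Map.empty[String, Seq[Setting]]

  def enable(settings: Seq[Setting])(project: Project): Unit = {
    val existing = projectsMap.getOrElse(project.id, Seq.empty[Setting])
    projectsMap += (project.id -> (existing ++ settings))
  }

  // Applied in one place, as projectDefinitions now does for every sbt project.
  def projectDefinitions(all: Seq[Project]): Seq[Project] =
    all.map(p => p.withSettings(projectsMap.getOrElse(p.id, Seq.empty[Setting])))

  def main(args: Array[String]): Unit = {
    val core = Project("core")
    val sql = Project("sql")
    val shared = Seq(Setting("scalaVersion", "2.10.4"))
    Seq(core, sql).foreach(enable(shared))                          // shared settings everywhere
    enable(Seq(Setting("parallelExecution.in.Test", "false")))(sql) // project-specific extras
    projectDefinitions(Seq(core, sql)).foreach(println)
  }
}
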
diff --git a/project/build.properties b/project/build.properties
index bcde13f4362a7..c12ef652adfcb 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.2
+sbt.version=0.13.5
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 472819b9fb8ba..d3ac4bf335e87 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -21,6 +21,6 @@ addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
-addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.0")
+addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0")
diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala
index e9fba641eb8a1..3ef2d5451da0d 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/project/project/SparkPluginBuild.scala
@@ -24,8 +24,10 @@ import sbt.Keys._
* becomes available for scalastyle sbt plugin.
*/
object SparkPluginDef extends Build {
- lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle)
+ lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader)
lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings)
+ lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git")
+
// There is actually no need to publish this artifact.
def styleSettings = Defaults.defaultSettings ++ Seq (
name := "spark-style",
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 0dbead4415b02..2a17127a7e0f9 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -56,7 +56,7 @@ def preexec_func():
(stdout, _) = proc.communicate()
exit_code = proc.poll()
error_msg = "Launching GatewayServer failed"
- error_msg += " with exit code %d!" % exit_code if exit_code else "! "
+ error_msg += " with exit code %d! " % exit_code if exit_code else "! "
error_msg += "(Warning: unexpected output detected.)\n\n"
error_msg += gateway_port + stdout
raise Exception(error_msg)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f64f48e3a4c9c..0c35c666805dd 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -69,16 +69,19 @@ def _extract_concise_traceback():
file, line, fun, what = tb[0]
return callsite(function=fun, file=file, linenum=line)
sfile, sline, sfun, swhat = tb[first_spark_frame]
- ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
return callsite(function=sfun, file=ufile, linenum=uline)
_spark_stack_depth = 0
+
class _JavaStackTrace(object):
+
def __init__(self, sc):
tb = _extract_concise_traceback()
if tb is not None:
- self._traceback = "%s at %s:%s" % (tb.function, tb.file, tb.linenum)
+ self._traceback = "%s at %s:%s" % (
+ tb.function, tb.file, tb.linenum)
else:
self._traceback = "Error! Could not extract traceback info"
self._context = sc
@@ -95,7 +98,9 @@ def __exit__(self, type, value, tb):
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(None)
+
class MaxHeapQ(object):
+
"""
An implementation of MaxHeap.
>>> import pyspark.rdd
@@ -117,14 +122,14 @@ class MaxHeapQ(object):
"""
def __init__(self, maxsize):
- # we start from q[1], this makes calculating children as trivial as 2 * k
+ # We start from q[1], so its children are always 2 * k
self.q = [0]
self.maxsize = maxsize
def _swim(self, k):
- while (k > 1) and (self.q[k/2] < self.q[k]):
- self._swap(k, k/2)
- k = k/2
+ while (k > 1) and (self.q[k / 2] < self.q[k]):
+ self._swap(k, k / 2)
+ k = k / 2
def _swap(self, i, j):
t = self.q[i]
@@ -162,7 +167,9 @@ def _replaceRoot(self, value):
self.q[1] = value
self._sink(1)
+
class RDD(object):
+
"""
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
Represents an immutable, partitioned collection of elements that can be
@@ -257,7 +264,8 @@ def map(self, f, preservesPartitioning=False):
>>> sorted(rdd.map(lambda x: (x, 1)).collect())
[('a', 1), ('b', 1), ('c', 1)]
"""
- def func(split, iterator): return imap(f, iterator)
+ def func(split, iterator):
+ return imap(f, iterator)
return PipelinedRDD(self, func, preservesPartitioning)
def flatMap(self, f, preservesPartitioning=False):
@@ -271,7 +279,8 @@ def flatMap(self, f, preservesPartitioning=False):
>>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
- def func(s, iterator): return chain.from_iterable(imap(f, iterator))
+ def func(s, iterator):
+ return chain.from_iterable(imap(f, iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
@@ -283,7 +292,8 @@ def mapPartitions(self, f, preservesPartitioning=False):
>>> rdd.mapPartitions(f).collect()
[3, 7]
"""
- def func(s, iterator): return f(iterator)
+ def func(s, iterator):
+ return f(iterator)
return self.mapPartitionsWithIndex(func)
def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
@@ -311,17 +321,17 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
6
"""
warnings.warn("mapPartitionsWithSplit is deprecated; "
- "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2)
+ "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2)
return self.mapPartitionsWithIndex(f, preservesPartitioning)
def getNumPartitions(self):
- """
- Returns the number of partitions in RDD
- >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
- >>> rdd.getNumPartitions()
- 2
- """
- return self._jrdd.partitions().size()
+ """
+ Returns the number of partitions in RDD
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> rdd.getNumPartitions()
+ 2
+ """
+ return self._jrdd.partitions().size()
def filter(self, f):
"""
@@ -331,7 +341,8 @@ def filter(self, f):
>>> rdd.filter(lambda x: x % 2 == 0).collect()
[2, 4]
"""
- def func(iterator): return ifilter(f, iterator)
+ def func(iterator):
+ return ifilter(f, iterator)
return self.mapPartitions(func)
def distinct(self):
@@ -391,9 +402,11 @@ def takeSample(self, withReplacement, num, seed=None):
maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint))
if num > maxSampleSize:
- raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
+ raise ValueError(
+ "Sample size cannot be greater than %d." % maxSampleSize)
- fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
+ fraction = RDD._computeFractionForSampleSize(
+ num, initialCount, withReplacement)
samples = self.sample(withReplacement, fraction, seed).collect()
# If the first sample didn't turn out large enough, keep trying to take samples;
@@ -499,17 +512,17 @@ def __add__(self, other):
raise TypeError
return self.union(other)
- def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
"""
Sorts this RDD, which is assumed to consist of (key, value) pairs.
-
+ # noqa
>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
>>> sc.parallelize(tmp).sortByKey(True, 2).collect()
[('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
>>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
>>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
>>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
- [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+ [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5),...('white', 9), ('whose', 6)]
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
@@ -521,10 +534,12 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
# number of (key, value) pairs falling into them
if numPartitions > 1:
rddSize = self.count()
- maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+ # constant from Spark's RangePartitioner
+ maxSampleSize = numPartitions * 20.0
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
- samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+ samples = self.sample(False, fraction, 1).map(
+ lambda (k, v): k).collect()
samples = sorted(samples, reverse=(not ascending), key=keyfunc)
# we have numPartitions many parts but one of the them has
@@ -540,13 +555,13 @@ def rangePartitionFunc(k):
if ascending:
return p
else:
- return numPartitions-1-p
+ return numPartitions - 1 - p
def mapFunc(iterator):
yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
- .mapPartitions(mapFunc,preservesPartitioning=True)
+ .mapPartitions(mapFunc, preservesPartitioning=True)
.flatMap(lambda x: x, preservesPartitioning=True))
def sortBy(self, keyfunc, ascending=True, numPartitions=None):
@@ -570,7 +585,8 @@ def glom(self):
>>> sorted(rdd.glom().collect())
[[1, 2], [3, 4]]
"""
- def func(iterator): yield list(iterator)
+ def func(iterator):
+ yield list(iterator)
return self.mapPartitions(func)
def cartesian(self, other):
@@ -607,7 +623,9 @@ def pipe(self, command, env={}):
['1', '2', '', '3']
"""
def func(iterator):
- pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+ pipe = Popen(
+ shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+
def pipe_objs(out):
for obj in iterator:
out.write(str(obj).rstrip('\n') + '\n')
@@ -646,7 +664,7 @@ def collect(self):
Return a list that contains all of the elements in this RDD.
"""
with _JavaStackTrace(self.context) as st:
- bytesInJava = self._jrdd.collect().iterator()
+ bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
@@ -736,7 +754,6 @@ def func(iterator):
return self.mapPartitions(func).fold(zeroValue, combOp)
-
def max(self):
"""
Find the maximum item in this RDD.
@@ -844,6 +861,7 @@ def countPartition(iterator):
for obj in iterator:
counts[obj] += 1
yield counts
+
def mergeMaps(m1, m2):
for (k, v) in m2.iteritems():
m1[k] += v
@@ -888,22 +906,22 @@ def takeOrdered(self, num, key=None):
def topNKeyedElems(iterator, key_=None):
q = MaxHeapQ(num)
for k in iterator:
- if key_ != None:
+ if key_ is not None:
k = (key_(k), k)
q.insert(k)
yield q.getElements()
def unKey(x, key_=None):
- if key_ != None:
+ if key_ is not None:
x = [i[1] for i in x]
return x
def merge(a, b):
return next(topNKeyedElems(a + b))
- result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+ result = self.mapPartitions(
+ lambda i: topNKeyedElems(i, key)).reduce(merge)
return sorted(unKey(result, key), key=key)
-
def take(self, num):
"""
Take the first num elements of the RDD.
@@ -947,7 +965,8 @@ def takeUpToNumLeft(iterator):
yield next(iterator)
taken += 1
- p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+ p = range(
+ partsScanned, min(partsScanned + numPartsToTry, totalParts))
res = self.context.runJob(self, takeUpToNumLeft, p, True)
items += res
@@ -977,7 +996,7 @@ def saveAsPickleFile(self, path, batchSize=10):
[1, 2, 'rdd', 'spark']
"""
self._reserialize(BatchedSerializer(PickleSerializer(),
- batchSize))._jrdd.saveAsObjectFile(path)
+ batchSize))._jrdd.saveAsObjectFile(path)
def saveAsTextFile(self, path):
"""
@@ -1075,6 +1094,7 @@ def reducePartition(iterator):
for (k, v) in iterator:
m[k] = v if k not in m else func(m[k], v)
yield m
+
def mergeMaps(m1, m2):
for (k, v) in m2.iteritems():
m1[k] = v if k not in m1 else func(m1[k], v)
@@ -1162,6 +1182,7 @@ def partitionBy(self, numPartitions, partitionFunc=None):
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
outputSerializer = self.ctx._unbatched_serializer
+
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@@ -1174,7 +1195,8 @@ def add_shuffle_key(split, iterator):
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
with _JavaStackTrace(self.context) as st:
- pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ pairRDD = self.ctx._jvm.PairwiseRDD(
+ keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
@@ -1213,6 +1235,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+
def combineLocally(iterator):
combiners = {}
for x in iterator:
@@ -1224,10 +1247,11 @@ def combineLocally(iterator):
return combiners.iteritems()
locally_combined = self.mapPartitions(combineLocally)
shuffled = locally_combined.partitionBy(numPartitions)
+
def _mergeCombiners(iterator):
combiners = {}
for (k, v) in iterator:
- if not k in combiners:
+ if k not in combiners:
combiners[k] = v
else:
combiners[k] = mergeCombiners(combiners[k], v)
@@ -1236,17 +1260,19 @@ def _mergeCombiners(iterator):
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
- Aggregate the values of each key, using given combine functions and a neutral "zero value".
- This function can return a different result type, U, than the type of the values in this RDD,
- V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
- The former operation is used for merging values within a partition, and the latter is used
- for merging values between partitions. To avoid memory allocation, both of these functions are
+ Aggregate the values of each key, using given combine functions and a neutral
+ "zero value". This function can return a different result type, U, than the type
+ of the values in this RDD, V. Thus, we need one operation for merging a V into
+ a U and one operation for merging two U's. The former operation is used for merging
+ values within a partition, and the latter is used for merging values between
+ partitions. To avoid memory allocation, both of these functions are
allowed to modify and return their first argument instead of creating a new U.
"""
def createZero():
- return copy.deepcopy(zeroValue)
+ return copy.deepcopy(zeroValue)
- return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
+ return self.combineByKey(
+ lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
def foldByKey(self, zeroValue, func, numPartitions=None):
"""
@@ -1261,11 +1287,10 @@ def foldByKey(self, zeroValue, func, numPartitions=None):
[('a', 2), ('b', 1)]
"""
def createZero():
- return copy.deepcopy(zeroValue)
+ return copy.deepcopy(zeroValue)
return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
-
# TODO: support variant with custom partitioner
def groupByKey(self, numPartitions=None):
"""
@@ -1292,7 +1317,7 @@ def mergeCombiners(a, b):
return a + b
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numPartitions).mapValues(lambda x: ResultIterable(x))
+ numPartitions).mapValues(lambda x: ResultIterable(x))
# TODO: add tests
def flatMapValues(self, f):
@@ -1362,7 +1387,8 @@ def subtractByKey(self, other, numPartitions=None):
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
"""
- filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
+ def filter_func((key, vals)):
+ return len(vals[0]) > 0 and len(vals[1]) == 0
map_func = lambda (key, vals): [(key, val) for val in vals[0]]
return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
@@ -1375,8 +1401,9 @@ def subtract(self, other, numPartitions=None):
>>> sorted(x.subtract(y).collect())
[('a', 1), ('b', 4), ('b', 5)]
"""
- rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder
- return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder
+ # note: here 'True' is just a placeholder
+ rdd = other.map(lambda x: (x, True))
+ return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0])
def keyBy(self, f):
"""
@@ -1434,7 +1461,7 @@ def zip(self, other):
"""
pairRDD = self._jrdd.zip(other._jrdd)
deserializer = PairDeserializer(self._jrdd_deserializer,
- other._jrdd_deserializer)
+ other._jrdd_deserializer)
return RDD(pairRDD, self.ctx, deserializer)
def name(self):
@@ -1503,7 +1530,9 @@ def _defaultReducePartitions(self):
# keys in the pairs. This could be an expensive operation, since those
# hashes aren't retained.
+
class PipelinedRDD(RDD):
+
"""
Pipelined maps:
>>> rdd = sc.parallelize([1, 2, 3, 4])
@@ -1519,6 +1548,7 @@ class PipelinedRDD(RDD):
>>> rdd.flatMap(lambda x: [x, x]).reduce(add)
20
"""
+
def __init__(self, prev, func, preservesPartitioning=False):
if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
# This transformation is the first in its stage:
@@ -1528,6 +1558,7 @@ def __init__(self, prev, func, preservesPartitioning=False):
self._prev_jrdd_deserializer = prev._jrdd_deserializer
else:
prev_func = prev.func
+
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
@@ -1560,11 +1591,13 @@ def _jrdd(self):
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
- self.ctx._gateway._gateway_client)
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_command), env, includes, self.preservesPartitioning,
- self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
- class_tag)
+ bytearray(pickled_command),
+ env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec,
+ broadcast_vars, self.ctx._javaAccumulator,
+ class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
@@ -1579,7 +1612,8 @@ def _test():
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
diff --git a/repl/pom.xml b/repl/pom.xml
index 4a66408ef3d2d..4ebb1b82f0e8c 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -32,6 +32,7 @@
http://spark.apache.org/
+ repl/usr/share/sparkroot
diff --git a/sbt/sbt b/sbt/sbt
index 9de265bd07dcb..1b1aa1483a829 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -72,6 +72,7 @@ Usage: $script_name [options]
-J-X pass option -X directly to the java runtime
(-J is stripped)
-S-X add -X to sbt's scalacOptions (-J is stripped)
+ -PmavenProfiles Enable a maven profile for the build.
In the case of duplicated or conflicting options, the order above
shows precedence: JAVA_OPTS lowest, command line options highest.
diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash
index 64e40a88206be..c91fecf024ad4 100755
--- a/sbt/sbt-launch-lib.bash
+++ b/sbt/sbt-launch-lib.bash
@@ -16,6 +16,7 @@ declare -a residual_args
declare -a java_args
declare -a scalac_args
declare -a sbt_commands
+declare -a maven_profiles
if test -x "$JAVA_HOME/bin/java"; then
echo -e "Using $JAVA_HOME as default JAVA_HOME."
@@ -87,6 +88,13 @@ addJava () {
dlog "[addJava] arg = '$1'"
java_args=( "${java_args[@]}" "$1" )
}
+
+enableProfile () {
+ dlog "[enableProfile] arg = '$1'"
+ maven_profiles=( "${maven_profiles[@]}" "$1" )
+ export SBT_MAVEN_PROFILES="${maven_profiles[@]}"
+}
+
addSbt () {
dlog "[addSbt] arg = '$1'"
sbt_commands=( "${sbt_commands[@]}" "$1" )
@@ -141,7 +149,8 @@ process_args () {
-java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && export JAVA_HOME=$2 && shift 2 ;;
-D*) addJava "$1" && shift ;;
- -J*) addJava "${1:2}" && shift ;;
+ -J*) addJava "${1:2}" && shift ;;
+ -P*) enableProfile "$1" && shift ;;
*) addResidual "$1" && shift ;;
esac
done
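
The `-P` handling above only collects profile names and exports them as the space-separated SBT_MAVEN_PROFILES variable; whatever consumes that variable (presumably the Maven-reading side of the build) has to split it again. A hypothetical Scala sketch of reading it back, just to show the contract assumed here:

object MavenProfilesSketch {
  // Assumes sbt-launch-lib.bash exported profiles as a space-separated list.
  def profiles: Seq[String] =
    sys.env.get("SBT_MAVEN_PROFILES")
      .map(_.trim.split("""\s+""").filter(_.nonEmpty).toSeq)
      .getOrElse(Seq.empty)

  def main(args: Array[String]): Unit =
    println("Enabled Maven profiles: " + profiles.mkString(", "))
}
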
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 01d7b569080ea..6decde3fcd62d 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -31,6 +31,9 @@
jarSpark Project Catalysthttp://spark.apache.org/
+
+ catalyst
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
index a6ce90854dcb4..22941edef2d46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
* of itself with globally unique expression ids.
*/
trait MultiInstanceRelation {
- def newInstance: this.type
+ def newInstance(): this.type
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 347471cebdc7e..b3850533c3736 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.catalyst.types.StringType
import org.apache.spark.sql.catalyst.types.BooleanType
-
trait StringRegexExpression {
self: BinaryExpression =>
@@ -32,7 +31,7 @@ trait StringRegexExpression {
def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean
- def nullable: Boolean = true
+ def nullable: Boolean = left.nullable || right.nullable
def dataType: DataType = BooleanType
// try cache the pattern for Literal
@@ -157,19 +156,13 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
override def toString() = s"Lower($child)"
}
-/** A base class for functions that compare two strings, returning a boolean. */
-abstract class StringComparison extends Expression {
- self: Product =>
+/** A base trait for functions that compare two strings, returning a boolean. */
+trait StringComparison {
+ self: BinaryExpression =>
type EvaluatedType = Any
- def left: Expression
- def right: Expression
-
- override def references = children.flatMap(_.references).toSet
- override def children = left :: right :: Nil
-
- override def nullable: Boolean = true
+ def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
def compare(l: String, r: String): Boolean
@@ -184,26 +177,31 @@ abstract class StringComparison extends Expression {
}
}
+ def symbol: String = nodeName
+
override def toString() = s"$nodeName($left, $right)"
}
/**
* A function that returns true if the string `left` contains the string `right`.
*/
-case class Contains(left: Expression, right: Expression) extends StringComparison {
+case class Contains(left: Expression, right: Expression)
+ extends BinaryExpression with StringComparison {
override def compare(l: String, r: String) = l.contains(r)
}
/**
* A function that returns true if the string `left` starts with the string `right`.
*/
-case class StartsWith(left: Expression, right: Expression) extends StringComparison {
+case class StartsWith(left: Expression, right: Expression)
+ extends BinaryExpression with StringComparison {
def compare(l: String, r: String) = l.startsWith(r)
}
/**
* A function that returns true if the string `left` ends with the string `right`.
*/
-case class EndsWith(left: Expression, right: Expression) extends StringComparison {
+case class EndsWith(left: Expression, right: Expression)
+ extends BinaryExpression with StringComparison {
def compare(l: String, r: String) = l.endsWith(r)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f0904f59d028f..a142310c501b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -123,6 +123,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
val startsWith = "([^_%]+)%".r
val endsWith = "%([^_%]+)".r
val contains = "%([^_%]+)%".r
+ val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") =>
@@ -131,6 +132,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
EndsWith(l, Literal(pattern))
case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") =>
Contains(l, Literal(pattern))
+ case Like(l, Literal(equalTo(pattern), StringType)) =>
+ EqualTo(l, Literal(pattern))
}
}
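
The new `equalTo` case above means a LIKE pattern containing no `%` or `_` wildcards collapses to a plain equality comparison, alongside the existing prefix/suffix/infix rewrites. As a standalone, hedged illustration of how those four regexes classify patterns (plain Scala with a hypothetical `classify` helper, not Catalyst code):

object LikePatternSketch {
  // The same four pattern shapes LikeSimplification now recognizes.
  val startsWith = "([^_%]+)%".r
  val endsWith = "%([^_%]+)".r
  val contains = "%([^_%]+)%".r
  val equalTo = "([^_%]*)".r

  // Hypothetical helper naming the simplification a LIKE pattern would get.
  def classify(pattern: String): String = pattern match {
    case startsWith(p) if !p.endsWith("\\") => s"StartsWith($p)"
    case endsWith(p)                        => s"EndsWith($p)"
    case contains(p) if !p.endsWith("\\")   => s"Contains($p)"
    case equalTo(p)                         => s"EqualTo($p)"
    case _                                  => "Like (not simplified)"
  }

  def main(args: Array[String]): Unit =
    Seq("abc%", "%abc", "%abc%", "abc", "a_c").foreach { p =>
      println(s"$p -> ${classify(p)}")
    }
}
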
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 8210fd1f210d1..c309c43804d97 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -31,6 +31,9 @@
jarSpark Project SQLhttp://spark.apache.org/
+
+ sql
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b6fb46a3acc03..2b787e14f3f15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -52,7 +52,7 @@ trait SQLConf {
/** ********************** SQLConf functionality methods ************ */
@transient
- protected[sql] val settings = java.util.Collections.synchronizedMap(
+ private val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
def set(props: Properties): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8bcfc7c064c2f..0c95b668545f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -256,6 +256,26 @@ class SchemaRDD(
def unionAll(otherPlan: SchemaRDD) =
new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan))
+ /**
+ * Performs a relational except on two SchemaRDDs
+ *
+ * @param otherPlan the [[SchemaRDD]] that should be excepted from this one.
+ *
+ * @group Query
+ */
+ def except(otherPlan: SchemaRDD): SchemaRDD =
+ new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan))
+
+ /**
+ * Performs a relational intersect on two SchemaRDDs
+ *
+ * @param otherPlan the [[SchemaRDD]] that should be intersected with this one.
+ *
+ * @group Query
+ */
+ def intersect(otherPlan: SchemaRDD): SchemaRDD =
+ new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan))
+
/**
* Filters tuples using a function over the value of the specified column.
*
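
For context, a hypothetical usage sketch of the two operators added above, assuming an existing SQLContext with two registered tables of identical schema (the table names below are made up):

import org.apache.spark.sql.SQLContext

object SetOpsSketch {
  def demo(sqlContext: SQLContext): Unit = {
    val a = sqlContext.sql("SELECT * FROM table_a")   // hypothetical table
    val b = sqlContext.sql("SELECT * FROM table_b")   // hypothetical table
    val onlyInA = a.except(b)      // rows of a that do not appear in b
    val inBoth  = a.intersect(b)   // rows that appear in both
    println(onlyInA.count() + " rows only in a, " + inBoth.count() + " rows in both")
  }
}
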
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index e1e4f24c6c66c..ff7f664d8b529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.columnar
+import java.nio.ByteBuffer
+
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -26,22 +29,19 @@ import org.apache.spark.SparkConf
object InMemoryRelation {
def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, child)
+ new InMemoryRelation(child.output, useCompression, child)()
}
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
child: SparkPlan)
+ (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
extends LogicalPlan with MultiInstanceRelation {
- override def children = Seq.empty
- override def references = Set.empty
-
- override def newInstance() =
- new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type]
-
- lazy val cachedColumnBuffers = {
+ // If the cached column buffers were not passed in, we calculate them in the constructor.
+ // As in Spark, the actual work of caching is lazy.
+ if (_cachedColumnBuffers == null) {
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
val columnBuilders = output.map { attribute =>
@@ -62,10 +62,23 @@ private[sql] case class InMemoryRelation(
}.cache()
cached.setName(child.toString)
- // Force the materialization of the cached RDD.
- cached.count()
- cached
+ _cachedColumnBuffers = cached
}
+
+
+ override def children = Seq.empty
+
+ override def references = Set.empty
+
+ override def newInstance() = {
+ new InMemoryRelation(
+ output.map(_.newInstance),
+ useCompression,
+ child)(
+ _cachedColumnBuffers).asInstanceOf[this.type]
+ }
+
+ def cachedColumnBuffers = _cachedColumnBuffers
}
private[sql] case class InMemoryColumnarTableScan(
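
The key move above is passing `_cachedColumnBuffers` through a second, curried constructor parameter list, so `newInstance()` can hand the already-materialized RDD to the copy instead of recomputing it. A minimal sketch of that idiom in plain Scala (hypothetical `CachedThing`, with a Vector standing in for the cached RDD):

// Sketch only: carry an expensive, lazily built cache through copies via a
// second (curried) constructor parameter list, as InMemoryRelation now does.
class CachedThing(val name: String)(private var _cache: Vector[Int]) {
  // Build the cache only if it was not handed in by an earlier instance.
  if (_cache == null) {
    _cache = Vector.tabulate(5)(i => i * i) // stands in for the column buffers
  }

  def cache: Vector[Int] = _cache

  // A copy that reuses the existing cache rather than rebuilding it.
  def newInstance(): CachedThing = new CachedThing(name)(_cache)
}

object CachedThingDemo {
  def main(args: Array[String]): Unit = {
    val a = new CachedThing("demo")(null) // cache built here
    val b = a.newInstance()               // cache reused, not rebuilt
    println(a.cache eq b.cache)           // true: same underlying vector
  }
}
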
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7080074a69c07..c078e71fe0290 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -247,8 +247,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
- execution.Aggregate(
- partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
+ execution.Distinct(partial = false,
+ execution.Distinct(partial = true, planLater(child))) :: Nil
case logical.Sort(sortExprs, child) =>
// This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
execution.Sort(sortExprs, global = true, planLater(child)):: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 97abd636ab5fb..966d8f95fc83c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.util.MutablePair
/**
@@ -248,6 +248,37 @@ object ExistingRdd {
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
+/**
+ * :: DeveloperApi ::
+ * Computes the set of distinct input rows using a HashSet.
+ * @param partial when true the distinct operation is performed partially, per partition, without
+ * shuffling the data.
+ * @param child the input query plan.
+ */
+@DeveloperApi
+case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode {
+ override def output = child.output
+
+ override def requiredChildDistribution =
+ if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil
+
+ override def execute() = {
+ child.execute().mapPartitions { iter =>
+ val hashSet = new scala.collection.mutable.HashSet[Row]()
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ if (!hashSet.contains(currentRow)) {
+ hashSet.add(currentRow.copy())
+ }
+ }
+
+ hashSet.iterator
+ }
+ }
+}
+
/**
* :: DeveloperApi ::
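
Together with the SparkStrategies change earlier in this patch, DISTINCT now executes in two phases: a partial, per-partition dedup that avoids a shuffle, wrapped in a final dedup over a ClusteredDistribution of the rows. A minimal sketch of the same two-phase idea on plain Scala collections (hypothetical, no Spark types involved):

object TwoPhaseDistinctSketch {
  // Phase 1: deduplicate within each "partition" without moving any data.
  def partialDistinct[T](partitions: Seq[Seq[T]]): Seq[Seq[T]] =
    partitions.map(_.distinct)

  // Phase 2: bring equal rows together (the shuffle), then deduplicate globally.
  def finalDistinct[T](partitions: Seq[Seq[T]]): Seq[T] =
    partitions.flatten.distinct

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 1, 2, 3), Seq(2, 3, 3, 4))
    val result = finalDistinct(partialDistinct(partitions))
    println(result.sorted) // List(1, 2, 3, 4)
  }
}
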
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 889a408e3c393..de8fe2dae38f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -114,7 +114,7 @@ private[sql] object CatalystConverter {
}
}
// All other primitive types use the default converter
- case ctype: NativeType => { // note: need the type tag here!
+ case ctype: PrimitiveType => { // note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
}
case _ => throw new RuntimeException(
@@ -229,9 +229,9 @@ private[parquet] class CatalystGroupConverter(
this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null)
protected [parquet] val converters: Array[Converter] =
- schema.map(field =>
- CatalystConverter.createConverter(field, schema.indexOf(field), this))
- .toArray
+ schema.zipWithIndex.map {
+ case (field, idx) => CatalystConverter.createConverter(field, idx, this)
+ }.toArray
override val size = schema.size
@@ -288,9 +288,9 @@ private[parquet] class CatalystPrimitiveRowConverter(
new ParquetRelation.RowType(attributes.length))
protected [parquet] val converters: Array[Converter] =
- schema.map(field =>
- CatalystConverter.createConverter(field, schema.indexOf(field), this))
- .toArray
+ schema.zipWithIndex.map {
+ case (field, idx) => CatalystConverter.createConverter(field, idx, this)
+ }.toArray
override val size = schema.size
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 9cd5dc5bbd393..f1953a008a49b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -156,7 +156,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
writer.startMessage()
while(index < attributes.size) {
// null values indicate optional fields but we do not check currently
- if (record(index) != null && record(index) != Nil) {
+ if (record(index) != null) {
writer.startField(attributes(index).name, index)
writeValue(attributes(index).dataType, record(index))
writer.endField(attributes(index).name, index)
@@ -167,7 +167,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
- if (value != null && value != Nil) {
+ if (value != null) {
schema match {
case t @ ArrayType(_) => writeArray(
t,
@@ -184,13 +184,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = {
- if (value != null && value != Nil) {
+ if (value != null) {
schema match {
case StringType => writer.addBinary(
Binary.fromByteArray(
value.asInstanceOf[String].getBytes("utf-8")
)
)
+ case BinaryType => writer.addBinary(
+ Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
case ShortType => writer.addInteger(value.asInstanceOf[Short])
case LongType => writer.addLong(value.asInstanceOf[Long])
@@ -206,12 +208,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeStruct(
schema: StructType,
struct: CatalystConverter.StructScalaType[_]): Unit = {
- if (struct != null && struct != Nil) {
+ if (struct != null) {
val fields = schema.fields.toArray
writer.startGroup()
var i = 0
while(i < fields.size) {
- if (struct(i) != null && struct(i) != Nil) {
+ if (struct(i) != null) {
writer.startField(fields(i).name, i)
writeValue(fields(i).dataType, struct(i))
writer.endField(fields(i).name, i)
@@ -299,6 +301,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
record(index).asInstanceOf[String].getBytes("utf-8")
)
)
+ case BinaryType => writer.addBinary(
+ Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
case IntegerType => writer.addInteger(record.getInt(index))
case ShortType => writer.addInteger(record.getShort(index))
case LongType => writer.addLong(record.getLong(index))
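With BinaryType supported, the write path now distinguishes two uses of Parquet's BINARY primitive: Catalyst strings are written as their UTF-8 bytes (and annotated UTF8 in the schema, see ParquetTestData below), while Array[Byte] values are written unchanged. A hedged sketch of just that dispatch, reusing the parquet Binary helper already used in this file; toBinary is an illustrative helper, not part of RowWriteSupport:

import parquet.io.api.Binary

object BinaryWriteSketch {
  // Illustrative only: how a Catalyst value would become a Parquet Binary.
  def toBinary(value: Any): Binary = value match {
    case s: String      => Binary.fromByteArray(s.getBytes("utf-8")) // logical type UTF8
    case b: Array[Byte] => Binary.fromByteArray(b)                   // plain BINARY
    case other          => sys.error(s"unsupported: $other")
  }

  def main(args: Array[String]): Unit = {
    println(toBinary("abc").length())             // 3 bytes of UTF-8 text
    println(toBinary(Array[Byte](1, 2)).length()) // 2 raw bytes
  }
}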
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index 1dc58633a2a68..d4599da711254 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -58,7 +58,7 @@ private[sql] object ParquetTestData {
"""message myrecord {
optional boolean myboolean;
optional int32 myint;
- optional binary mystring;
+ optional binary mystring (UTF8);
optional int64 mylong;
optional float myfloat;
optional double mydouble;
@@ -87,7 +87,7 @@ private[sql] object ParquetTestData {
message myrecord {
required boolean myboolean;
required int32 myint;
- required binary mystring;
+ required binary mystring (UTF8);
required int64 mylong;
required float myfloat;
required double mydouble;
@@ -119,14 +119,14 @@ private[sql] object ParquetTestData {
// so that array types can be translated correctly.
"""
message AddressBook {
- required binary owner;
+ required binary owner (UTF8);
optional group ownerPhoneNumbers {
- repeated binary array;
+ repeated binary array (UTF8);
}
optional group contacts {
repeated group array {
- required binary name;
- optional binary phoneNumber;
+ required binary name (UTF8);
+ optional binary phoneNumber (UTF8);
}
}
}
@@ -181,16 +181,16 @@ private[sql] object ParquetTestData {
required int32 x;
optional group data1 {
repeated group map {
- required binary key;
+ required binary key (UTF8);
required int32 value;
}
}
required group data2 {
repeated group map {
- required binary key;
+ required binary key (UTF8);
required group value {
required int64 payload1;
- optional binary payload2;
+ optional binary payload2 (UTF8);
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index f9046368e7ced..7f6ad908f78ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -42,20 +42,22 @@ private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean =
classOf[PrimitiveType] isAssignableFrom ctype.getClass
- def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match {
- case ParquetPrimitiveTypeName.BINARY => StringType
- case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
- case ParquetPrimitiveTypeName.DOUBLE => DoubleType
- case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType)
- case ParquetPrimitiveTypeName.FLOAT => FloatType
- case ParquetPrimitiveTypeName.INT32 => IntegerType
- case ParquetPrimitiveTypeName.INT64 => LongType
- case ParquetPrimitiveTypeName.INT96 =>
- // TODO: add BigInteger type? TODO(andre) use DecimalType instead????
- sys.error("Potential loss of precision: cannot convert INT96")
- case _ => sys.error(
- s"Unsupported parquet datatype $parquetType")
- }
+ def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType =
+ parquetType.getPrimitiveTypeName match {
+ case ParquetPrimitiveTypeName.BINARY
+ if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType
+ case ParquetPrimitiveTypeName.BINARY => BinaryType
+ case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
+ case ParquetPrimitiveTypeName.DOUBLE => DoubleType
+ case ParquetPrimitiveTypeName.FLOAT => FloatType
+ case ParquetPrimitiveTypeName.INT32 => IntegerType
+ case ParquetPrimitiveTypeName.INT64 => LongType
+ case ParquetPrimitiveTypeName.INT96 =>
+ // TODO: add BigInteger type? TODO(andre) use DecimalType instead????
+ sys.error("Potential loss of precision: cannot convert INT96")
+ case _ => sys.error(
+ s"Unsupported parquet datatype $parquetType")
+ }
/**
* Converts a given Parquet `Type` into the corresponding
@@ -104,7 +106,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
if (parquetType.isPrimitive) {
- toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName)
+ toPrimitiveDataType(parquetType.asPrimitiveType)
} else {
val groupType = parquetType.asGroupType()
parquetType.getOriginalType match {
@@ -164,18 +166,17 @@ private[parquet] object ParquetTypesConverter extends Logging {
* @return The name of the corresponding Parquet primitive type
*/
def fromPrimitiveDataType(ctype: DataType):
- Option[ParquetPrimitiveTypeName] = ctype match {
- case StringType => Some(ParquetPrimitiveTypeName.BINARY)
- case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN)
- case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE)
- case ArrayType(ByteType) =>
- Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
- case FloatType => Some(ParquetPrimitiveTypeName.FLOAT)
- case IntegerType => Some(ParquetPrimitiveTypeName.INT32)
+ Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match {
+ case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))
+ case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None)
+ case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None)
+ case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None)
+ case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None)
+ case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None)
// There is no type for Byte or Short so we promote them to INT32.
- case ShortType => Some(ParquetPrimitiveTypeName.INT32)
- case ByteType => Some(ParquetPrimitiveTypeName.INT32)
- case LongType => Some(ParquetPrimitiveTypeName.INT64)
+ case ShortType => Some(ParquetPrimitiveTypeName.INT32, None)
+ case ByteType => Some(ParquetPrimitiveTypeName.INT32, None)
+ case LongType => Some(ParquetPrimitiveTypeName.INT64, None)
case _ => None
}
@@ -227,9 +228,10 @@ private[parquet] object ParquetTypesConverter extends Logging {
if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
}
val primitiveType = fromPrimitiveDataType(ctype)
- if (primitiveType.isDefined) {
- new ParquetPrimitiveType(repetition, primitiveType.get, name)
- } else {
+ primitiveType.map {
+ case (primitiveType, originalType) =>
+ new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
+ }.getOrElse {
ctype match {
case ArrayType(elementType) => {
val parquetElementType = fromDataType(
@@ -237,7 +239,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
nullable = false,
inArray = true)
- ConversionPatterns.listType(repetition, name, parquetElementType)
+ ConversionPatterns.listType(repetition, name, parquetElementType)
}
case StructType(structFields) => {
val fields = structFields.map {
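After this change the two directions are consistent for binary data: StringType maps to BINARY plus the UTF8 original type, BinaryType maps to bare BINARY, and toPrimitiveDataType applies the same distinction when reading a schema back. A condensed, standalone sketch of that round trip; the string constants stand in for the parquet-mr enums and the two case objects for the Catalyst types:

object BinaryMappingSketch {
  sealed trait CatalystType
  case object StringType extends CatalystType
  case object BinaryType extends CatalystType

  // (primitive type name, optional original/logical type), as in fromPrimitiveDataType.
  def toParquet(t: CatalystType): (String, Option[String]) = t match {
    case StringType => ("BINARY", Some("UTF8"))
    case BinaryType => ("BINARY", None)
  }

  // The inverse direction, as in toPrimitiveDataType.
  def toCatalyst(primitive: String, original: Option[String]): CatalystType =
    (primitive, original) match {
      case ("BINARY", Some("UTF8")) => StringType
      case ("BINARY", None)         => BinaryType
      case other                    => sys.error(s"not modelled here: $other")
    }

  def main(args: Array[String]): Unit = {
    for (t <- Seq(StringType, BinaryType)) {
      val (p, o) = toParquet(t)
      assert(toCatalyst(p, o) == t) // the round trip preserves the distinction
    }
    println("string/binary round trip ok")
  }
}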
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 04ac008682f5f..68dae58728a2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -168,4 +168,25 @@ class DslQuerySuite extends QueryTest {
test("zero count") {
assert(emptyTableData.count() === 0)
}
+
+ test("except") {
+ checkAnswer(
+ lowerCaseData.except(upperCaseData),
+ (1, "a") ::
+ (2, "b") ::
+ (3, "c") ::
+ (4, "d") :: Nil)
+ checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
+ checkAnswer(upperCaseData.except(upperCaseData), Nil)
+ }
+
+ test("intersect") {
+ checkAnswer(
+ lowerCaseData.intersect(lowerCaseData),
+ (1, "a") ::
+ (2, "b") ::
+ (3, "c") ::
+ (4, "d") :: Nil)
+ checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
+ }
}
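The two new tests pin down the intended semantics: except keeps rows of the left relation that have no equal row on the right, and intersect keeps rows that appear on both sides (duplicate handling is not exercised here). On plain Scala pairs the same contract, as checked by the tests above, looks like this; the data only loosely mirrors lowerCaseData/upperCaseData:

object ExceptIntersectSketch {
  def main(args: Array[String]): Unit = {
    val lower = Seq(1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d")
    val upper = Seq(1 -> "A", 2 -> "B", 3 -> "C")

    // except: rows of the left side with no equal row on the right
    println(lower.filterNot(upper.contains)) // all of lower, since no pair matches
    println(lower.filterNot(lower.contains)) // List()

    // intersect: rows present on both sides
    println(lower.filter(lower.contains))    // all of lower
    println(lower.filter(upper.contains))    // List()
  }
}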
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 054b14f8f7ffa..e17ecc87fd52a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -36,36 +36,6 @@ class JoinSuite extends QueryTest {
assert(planned.size === 1)
}
- test("plans broadcast hash join, given hints") {
-
- def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = {
- TestSQLContext.settings.synchronized {
- TestSQLContext.set("spark.sql.join.broadcastTables",
- s"${if (buildSide == BuildRight) rightTable else leftTable}")
- val rdd = sql( s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""")
- // Using `sparkPlan` because for relevant patterns in HashJoin to be
- // matched, other strategies need to be applied.
- val physical = rdd.queryExecution.sparkPlan
- val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j}
-
- assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join")
- checkAnswer(
- rdd,
- Seq(
- (1, "1", 1, 1),
- (1, "1", 1, 2),
- (2, "2", 2, 1),
- (2, "2", 2, 2),
- (3, "3", 3, 1),
- (3, "3", 3, 2)
- ))
- }
- }
-
- mkTest(BuildRight, "testData", "testData2")
- mkTest(BuildLeft, "testData", "testData2")
- }
-
test("multiple-key equi-join is hash-join") {
val x = testData2.as('x)
val y = testData2.as('y)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 93792f698cfaf..08293f7f0ca30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -28,50 +28,46 @@ class SQLConfSuite extends QueryTest {
val testVal = "test.val.0"
test("programmatic ways of basic setting and getting") {
- TestSQLContext.settings.synchronized {
- clear()
- assert(getOption(testKey).isEmpty)
- assert(getAll.toSet === Set())
+ clear()
+ assert(getOption(testKey).isEmpty)
+ assert(getAll.toSet === Set())
- set(testKey, testVal)
- assert(get(testKey) == testVal)
- assert(get(testKey, testVal + "_") == testVal)
- assert(getOption(testKey) == Some(testVal))
- assert(contains(testKey))
+ set(testKey, testVal)
+ assert(get(testKey) == testVal)
+ assert(get(testKey, testVal + "_") == testVal)
+ assert(getOption(testKey) == Some(testVal))
+ assert(contains(testKey))
- // Tests SQLConf as accessed from a SQLContext is mutable after
- // the latter is initialized, unlike SparkConf inside a SparkContext.
- assert(TestSQLContext.get(testKey) == testVal)
- assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
- assert(TestSQLContext.getOption(testKey) == Some(testVal))
- assert(TestSQLContext.contains(testKey))
+ // Tests SQLConf as accessed from a SQLContext is mutable after
+ // the latter is initialized, unlike SparkConf inside a SparkContext.
+ assert(TestSQLContext.get(testKey) == testVal)
+ assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
+ assert(TestSQLContext.getOption(testKey) == Some(testVal))
+ assert(TestSQLContext.contains(testKey))
- clear()
- }
+ clear()
}
test("parse SQL set commands") {
- TestSQLContext.settings.synchronized {
- clear()
- sql(s"set $testKey=$testVal")
- assert(get(testKey, testVal + "_") == testVal)
- assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
+ clear()
+ sql(s"set $testKey=$testVal")
+ assert(get(testKey, testVal + "_") == testVal)
+ assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
- sql("set mapred.reduce.tasks=20")
- assert(get("mapred.reduce.tasks", "0") == "20")
- sql("set mapred.reduce.tasks = 40")
- assert(get("mapred.reduce.tasks", "0") == "40")
+ sql("set mapred.reduce.tasks=20")
+ assert(get("mapred.reduce.tasks", "0") == "20")
+ sql("set mapred.reduce.tasks = 40")
+ assert(get("mapred.reduce.tasks", "0") == "40")
- val key = "spark.sql.key"
- val vs = "val0,val_1,val2.3,my_table"
- sql(s"set $key=$vs")
- assert(get(key, "0") == vs)
+ val key = "spark.sql.key"
+ val vs = "val0,val_1,val2.3,my_table"
+ sql(s"set $key=$vs")
+ assert(get(key, "0") == vs)
- sql(s"set $key=")
- assert(get(key, "0") == "")
+ sql(s"set $key=")
+ assert(get(key, "0") == "")
- clear()
- }
+ clear()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fa1f32f8a49a9..0743cfe8cff0f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -397,40 +397,38 @@ class SQLQuerySuite extends QueryTest {
}
test("SET commands semantics using sql()") {
- TestSQLContext.settings.synchronized {
- clear()
- val testKey = "test.key.0"
- val testVal = "test.val.0"
- val nonexistentKey = "nonexistent"
-
- // "set" itself returns all config variables currently specified in SQLConf.
- assert(sql("SET").collect().size == 0)
-
- // "set key=val"
- sql(s"SET $testKey=$testVal")
- checkAnswer(
- sql("SET"),
- Seq(Seq(testKey, testVal))
- )
-
- sql(s"SET ${testKey + testKey}=${testVal + testVal}")
- checkAnswer(
- sql("set"),
- Seq(
- Seq(testKey, testVal),
- Seq(testKey + testKey, testVal + testVal))
- )
-
- // "set key"
- checkAnswer(
- sql(s"SET $testKey"),
- Seq(Seq(testKey, testVal))
- )
- checkAnswer(
- sql(s"SET $nonexistentKey"),
- Seq(Seq(nonexistentKey, ""))
- )
- clear()
- }
+ clear()
+ val testKey = "test.key.0"
+ val testVal = "test.val.0"
+ val nonexistentKey = "nonexistent"
+
+ // "set" itself returns all config variables currently specified in SQLConf.
+ assert(sql("SET").collect().size == 0)
+
+ // "set key=val"
+ sql(s"SET $testKey=$testVal")
+ checkAnswer(
+ sql("SET"),
+ Seq(Seq(testKey, testVal))
+ )
+
+ sql(s"SET ${testKey + testKey}=${testVal + testVal}")
+ checkAnswer(
+ sql("set"),
+ Seq(
+ Seq(testKey, testVal),
+ Seq(testKey + testKey, testVal + testVal))
+ )
+
+ // "set key"
+ checkAnswer(
+ sql(s"SET $testKey"),
+ Seq(Seq(testKey, testVal))
+ )
+ checkAnswer(
+ sql(s"SET $nonexistentKey"),
+ Seq(Seq(nonexistentKey, ""))
+ )
+ clear()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index dbf315947ff47..3c911e9a4e7b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -65,7 +65,8 @@ case class AllDataTypes(
doubleField: Double,
shortField: Short,
byteField: Byte,
- booleanField: Boolean)
+ booleanField: Boolean,
+ binaryField: Array[Byte])
case class AllDataTypesWithNonPrimitiveType(
stringField: String,
@@ -76,9 +77,10 @@ case class AllDataTypesWithNonPrimitiveType(
shortField: Short,
byteField: Byte,
booleanField: Boolean,
+ binaryField: Array[Byte],
array: Seq[Int],
map: Map[Int, String],
- nested: Nested)
+ data: Data)
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -116,7 +118,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
TestSQLContext.sparkContext.parallelize(range)
- .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+ .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
+ (0 to x).map(_.toByte).toArray))
.saveAsParquetFile(tempDir)
val result = parquetFile(tempDir).collect()
range.foreach {
@@ -129,6 +132,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result(i).getShort(5) === i.toShort)
assert(result(i).getByte(6) === i.toByte)
assert(result(i).getBoolean(7) === (i % 2 == 0))
+ assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
}
}
@@ -138,7 +142,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
TestSQLContext.sparkContext.parallelize(range)
.map(x => AllDataTypesWithNonPrimitiveType(
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- Seq(x), Map(x -> s"$x"), Nested(x, s"$x")))
+ (0 to x).map(_.toByte).toArray,
+ (0 until x), (0 until x).map(i => i -> s"$i").toMap, Data((0 until x), Nested(x, s"$x"))))
.saveAsParquetFile(tempDir)
val result = parquetFile(tempDir).collect()
range.foreach {
@@ -151,9 +156,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result(i).getShort(5) === i.toShort)
assert(result(i).getByte(6) === i.toByte)
assert(result(i).getBoolean(7) === (i % 2 == 0))
- assert(result(i)(8) === Seq(i))
- assert(result(i)(9) === Map(i -> s"$i"))
- assert(result(i)(10) === new GenericRow(Array[Any](i, s"$i")))
+ assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
+ assert(result(i)(9) === (0 until i))
+ assert(result(i)(10) === (0 until i).map(i => i -> s"$i").toMap)
+ assert(result(i)(11) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
}
}
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 5ede76e5c3904..f30ae28b81e06 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -31,6 +31,9 @@
  <packaging>jar</packaging>
  <name>Spark Project Hive</name>
  <url>http://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>hive</sbt.project.name>
+  </properties>
@@ -48,6 +51,11 @@
      <artifactId>hive-metastore</artifactId>
      <version>${hive.version}</version>
    </dependency>
+    <dependency>
+      <groupId>commons-httpclient</groupId>
+      <artifactId>commons-httpclient</artifactId>
+      <version>3.1</version>
+    </dependency>
    <dependency>
      <groupId>org.spark-project.hive</groupId>
      <artifactId>hive-exec</artifactId>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 8cfde46186ca4..c3942578d6b5a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -164,13 +164,17 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
hivePartitionRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val rowWithPartArr = new Array[Object](2)
+
+      // The partition-value update and the deserializer initialization are intentionally
+      // kept outside the iter.map loop below so they run once per partition, not per row.
+ rowWithPartArr.update(1, partValues)
+ val deserializer = localDeserializer.newInstance()
+ deserializer.initialize(hconf, partProps)
+
// Map each tuple to a row object
iter.map { value =>
- val deserializer = localDeserializer.newInstance()
- deserializer.initialize(hconf, partProps)
val deserializedRow = deserializer.deserialize(value)
rowWithPartArr.update(0, deserializedRow)
- rowWithPartArr.update(1, partValues)
rowWithPartArr.asInstanceOf[Object]
}
}
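The TableReader change is a standard mapPartitions hoisting: the deserializer is instantiated and initialized once per partition and the constant partition values are installed once, so the per-row closure only deserializes and fills slot 0. A generic sketch of the pattern on an ordinary RDD; Codec is a hypothetical stand-in for newInstance()/initialize():

import org.apache.spark.{SparkConf, SparkContext}

object MapPartitionsHoisting {
  // Hypothetical stand-in for a costly, reusable helper (e.g. a deserializer).
  class Codec { def decode(x: Int): String = s"row-$x" }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("hoisting"))
    val rows = sc.parallelize(1 to 1000, 4).mapPartitions { iter =>
      val codec = new Codec()   // built once per partition, not once per row
      iter.map(codec.decode)    // the per-row work stays minimal
    }
    println(rows.count())       // 1000
    sc.stop()
  }
}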
diff --git a/streaming/pom.xml b/streaming/pom.xml
index f506d6ce34a6f..f60697ce745b7 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -27,6 +27,9 @@
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-streaming_2.10</artifactId>
+  <properties>
+    <sbt.project.name>streaming</sbt.project.name>
+  </properties>
  <packaging>jar</packaging>
  <name>Spark Project Streaming</name>
  <url>http://spark.apache.org/</url>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
index 6376cff78b78a..ed7da6dc1315e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
@@ -41,7 +41,7 @@ class QueueInputDStream[T: ClassTag](
if (oneAtATime && queue.size > 0) {
buffer += queue.dequeue()
} else {
- buffer ++= queue
+ buffer ++= queue.dequeueAll(_ => true)
}
if (buffer.size > 0) {
if (oneAtATime) {
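The one-line QueueInputDStream fix matters because buffer ++= queue only copies the queue's current contents; with a SynchronizedQueue the elements stay enqueued and would be delivered again in every later batch. dequeueAll(_ => true) returns and removes everything. A minimal illustration with a plain mutable Queue:

import scala.collection.mutable.{ArrayBuffer, Queue}

object DrainQueueSketch {
  def main(args: Array[String]): Unit = {
    val queue = Queue(1, 2, 3)

    val copied = ArrayBuffer[Int]()
    copied ++= queue                        // copies only; queue still holds 1, 2, 3
    println(s"after copy:  queue=$queue")

    val drained = ArrayBuffer[Int]()
    drained ++= queue.dequeueAll(_ => true) // drains; queue is now empty
    println(s"after drain: queue=$queue, drained=$drained")
  }
}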
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 78cc2daa56e53..0316b6862f195 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -44,7 +44,7 @@ private[streaming] class BlockGenerator(
listener: BlockGeneratorListener,
receiverId: Int,
conf: SparkConf
- ) extends Logging {
+ ) extends RateLimiter(conf) with Logging {
private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any])
@@ -81,6 +81,7 @@ private[streaming] class BlockGenerator(
* will be periodically pushed into BlockManager.
*/
def += (data: Any): Unit = synchronized {
+ waitToPush()
currentBuffer += data
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
new file mode 100644
index 0000000000000..e4f6ba626ebbf
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.receiver
+
+import org.apache.spark.{Logging, SparkConf}
+import java.util.concurrent.TimeUnit._
+
+/** Provides the waitToPush() method used to limit the rate at which receivers consume data.
+ *
+ * waitToPush() blocks the calling thread if too many messages have been pushed too quickly,
+ * and only returns once a new message may be pushed. It assumes that only one message is
+ * pushed at a time.
+ *
+ * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages
+ * per second that each receiver will accept.
+ *
+ * @param conf spark configuration
+ */
+private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
+
+ private var lastSyncTime = System.nanoTime
+ private var messagesWrittenSinceSync = 0L
+ private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0)
+ private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+
+ def waitToPush() {
+    if (desiredRate <= 0) {
+ return
+ }
+ val now = System.nanoTime
+ val elapsedNanosecs = math.max(now - lastSyncTime, 1)
+ val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs
+ if (rate < desiredRate) {
+ // It's okay to write; just update some variables and return
+ messagesWrittenSinceSync += 1
+ if (now > lastSyncTime + SYNC_INTERVAL) {
+ // Sync interval has passed; let's resync
+ lastSyncTime = now
+ messagesWrittenSinceSync = 1
+ }
+ } else {
+ // Calculate how much time we should sleep to bring ourselves to the desired rate.
+ val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate
+ val elapsedTimeInMillis = elapsedNanosecs / 1000000
+ val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis
+ if (sleepTimeInMillis > 0) {
+ logTrace("Natural rate is " + rate + " per second but desired rate is " +
+ desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.")
+ Thread.sleep(sleepTimeInMillis)
+ }
+ waitToPush()
+ }
+ }
+}
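To make waitToPush()'s sleep computation concrete: with spark.streaming.receiver.maxRate = 200 and 10 messages already written 20 ms after the last sync, the observed rate is 500 msg/s, so the method sleeps 10 * 1000 / 200 - 20 = 30 ms before admitting the next message. The same arithmetic spelled out, with numbers chosen purely for illustration:

object RateLimitMath {
  def main(args: Array[String]): Unit = {
    val desiredRate              = 200  // messages per second
    val messagesWrittenSinceSync = 10L
    val elapsedTimeInMillis      = 20L

    val observedRate       = messagesWrittenSinceSync.toDouble * 1000 / elapsedTimeInMillis
    val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate
    val sleepTimeInMillis  = targetTimeInMillis - elapsedTimeInMillis

    println(s"observed rate = $observedRate msg/s") // 500.0
    println(s"sleep for $sleepTimeInMillis ms")     // 30
  }
}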
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index cd0aa4d0dce70..cc4a65011dd72 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -29,7 +29,7 @@ import java.nio.charset.Charset
import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
import com.google.common.io.Files
import org.scalatest.BeforeAndAfter
@@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.ManualClock
import org.apache.spark.util.Utils
import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
+import org.apache.spark.rdd.RDD
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
@@ -234,6 +235,95 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
logInfo("--------------------------------")
assert(output.sum === numTotalRecords)
}
+
+ test("queue input stream - oneAtATime=true") {
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(conf, batchDuration)
+ val queue = new SynchronizedQueue[RDD[String]]()
+ val queueStream = ssc.queueStream(queue, oneAtATime = true)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(queueStream, outputBuffer)
+ def output = outputBuffer.filter(_.size > 0)
+ outputStream.register()
+ ssc.start()
+
+ // Setup data queued into the stream
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq("1", "2", "3", "4", "5")
+ val expectedOutput = input.map(Seq(_))
+ //Thread.sleep(1000)
+ val inputIterator = input.toIterator
+ for (i <- 0 until input.size) {
+ // Enqueue more than 1 item per tick but they should dequeue one at a time
+ inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(1000)
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i) === expectedOutput(i))
+ }
+ }
+
+ test("queue input stream - oneAtATime=false") {
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(conf, batchDuration)
+ val queue = new SynchronizedQueue[RDD[String]]()
+ val queueStream = ssc.queueStream(queue, oneAtATime = false)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(queueStream, outputBuffer)
+ def output = outputBuffer.filter(_.size > 0)
+ outputStream.register()
+ ssc.start()
+
+ // Setup data queued into the stream
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq("1", "2", "3", "4", "5")
+ val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5"))
+
+ // Enqueue the first 3 items (one by one), they should be merged in the next batch
+ val inputIterator = input.toIterator
+ inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(1000)
+
+ // Enqueue the remaining items (again one by one), merged in the final batch
+ inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(1000)
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i) === expectedOutput(i))
+ }
+ }
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index d9ac3c91f6e36..f4e11f975de94 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -145,6 +145,44 @@ class NetworkReceiverSuite extends FunSuite with Timeouts {
assert(recordedData.toSet === generatedData.toSet)
}
+ test("block generator throttling") {
+ val blockGeneratorListener = new FakeBlockGeneratorListener
+ val blockInterval = 50
+ val maxRate = 200
+ val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString).
+ set("spark.streaming.receiver.maxRate", maxRate.toString)
+ val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
+ val expectedBlocks = 20
+ val waitTime = expectedBlocks * blockInterval
+ val expectedMessages = maxRate * waitTime / 1000
+ val expectedMessagesPerBlock = maxRate * blockInterval / 1000
+ val generatedData = new ArrayBuffer[Int]
+
+ // Generate blocks
+ val startTime = System.currentTimeMillis()
+ blockGenerator.start()
+ var count = 0
+ while(System.currentTimeMillis - startTime < waitTime) {
+ blockGenerator += count
+ generatedData += count
+ count += 1
+ Thread.sleep(1)
+ }
+ blockGenerator.stop()
+
+ val recordedData = blockGeneratorListener.arrayBuffers
+ assert(blockGeneratorListener.arrayBuffers.size > 0)
+ assert(recordedData.flatten.toSet === generatedData.toSet)
+ // recordedData size should be close to the expected rate
+ assert(recordedData.flatten.size >= expectedMessages * 0.9 &&
+ recordedData.flatten.size <= expectedMessages * 1.1 )
+ // the first and last block may be incomplete, so we slice them out
+ recordedData.slice(1, recordedData.size - 1).foreach { block =>
+ assert(block.size >= expectedMessagesPerBlock * 0.8 &&
+ block.size <= expectedMessagesPerBlock * 1.2 )
+ }
+ }
+
/**
* An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
*/
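The bounds asserted in the throttling test follow directly from its configuration: 20 blocks of 50 ms give a 1000 ms run, so at 200 msg/s roughly 200 messages in total and about 10 per block are expected, with 10-20% slack for scheduling jitter and the sliced-off first and last blocks. The arithmetic, for reference:

object ThrottlingTestMath {
  def main(args: Array[String]): Unit = {
    val blockInterval  = 50   // ms
    val maxRate        = 200  // messages per second
    val expectedBlocks = 20

    val waitTime                 = expectedBlocks * blockInterval // 1000 ms
    val expectedMessages         = maxRate * waitTime / 1000      // 200
    val expectedMessagesPerBlock = maxRate * blockInterval / 1000 // 10

    println(s"total ~ $expectedMessages, per block ~ $expectedMessagesPerBlock")
  }
}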
diff --git a/tools/pom.xml b/tools/pom.xml
index 79cd8551d0722..c0ee8faa7a615 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -26,6 +26,9 @@
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-tools_2.10</artifactId>
+  <properties>
+    <sbt.project.name>tools</sbt.project.name>
+  </properties>
  <packaging>jar</packaging>
  <name>Spark Project Tools</name>
  <url>http://spark.apache.org/</url>
diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml
index b8a631dd0bb3b..5b13a1f002d6e 100644
--- a/yarn/alpha/pom.xml
+++ b/yarn/alpha/pom.xml
@@ -23,6 +23,9 @@
    <version>1.1.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>
+  <properties>
+    <sbt.project.name>yarn-alpha</sbt.project.name>
+  </properties>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-yarn-alpha_2.10</artifactId>
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 438737f7a6b60..062f946a9fe93 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -184,6 +184,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
private def startUserClass(): Thread = {
logInfo("Starting the user JAR in a separate Thread")
+ System.setProperty("spark.executor.instances", args.numExecutors.toString)
val mainMethod = Class.forName(
args.userClass,
false /* initialize */ ,
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index 25cc9016b10a6..4c383ab574abe 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -26,7 +26,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
var userArgs: Seq[String] = Seq[String]()
var executorMemory = 1024
var executorCores = 1
- var numExecutors = 2
+ var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS
parseArgs(args.toList)
@@ -93,3 +93,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
System.exit(exitCode)
}
}
+
+object ApplicationMasterArguments {
+ val DEFAULT_NUMBER_EXECUTORS = 2
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 6b91e6b9eb899..15e8c21aa5906 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -40,8 +40,10 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configur
override def postStartHook() {
+ super.postStartHook()
// The yarn application is running, but the executor might not yet ready
// Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+    // TODO: this sleep should no longer be needed once waitBackendReady is in place
Thread.sleep(2000L)
logInfo("YarnClientClusterScheduler.postStartHook done")
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index fd2694fe7278d..0f9fdcfcb6510 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -75,6 +75,7 @@ private[spark] class YarnClientSchedulerBackend(
logDebug("ClientArguments called with: " + argsArrayBuf)
val args = new ClientArguments(argsArrayBuf.toArray, conf)
+ totalExpectedExecutors.set(args.numExecutors)
client = new Client(args, conf)
appId = client.runApp()
waitForApp()
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 39cdd2e8a522b..9ee53d797c8ea 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -48,9 +48,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ super.postStartHook()
if (sparkContextInitialized){
ApplicationMaster.waitForInitialAllocations()
// Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+      // TODO: this sleep should no longer be needed once waitBackendReady is in place
Thread.sleep(3000L)
}
logInfo("YarnClusterScheduler.postStartHook done")
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
new file mode 100644
index 0000000000000..a04b08f43cc5a
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.SparkContext
+import org.apache.spark.deploy.yarn.ApplicationMasterArguments
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.IntParam
+
+private[spark] class YarnClusterSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+
+ override def start() {
+ super.start()
+ var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS
+ if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) {
+ numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors)
+ }
+ // System property can override environment variable.
+ numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors)
+ totalExpectedExecutors.set(numExecutors)
+ }
+}
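Taken together with the System.setProperty call added to both ApplicationMasters, the executor count seen by YarnClusterSchedulerBackend resolves in a fixed order: the compiled default, then the SPARK_EXECUTOR_INSTANCES environment variable, then the spark.executor.instances property derived from --num-executors. A small sketch of that precedence, using plain Maps to stand in for the environment and SparkConf:

object ExecutorCountPrecedence {
  val DefaultNumberExecutors = 2 // mirrors ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS

  def resolve(env: Map[String, String], conf: Map[String, String]): Int = {
    var n = DefaultNumberExecutors
    // The environment variable overrides the default, if it parses as an Int.
    n = env.get("SPARK_EXECUTOR_INSTANCES")
      .flatMap(s => scala.util.Try(s.toInt).toOption)
      .getOrElse(n)
    // The system property / SparkConf entry overrides the environment variable.
    conf.get("spark.executor.instances").map(_.toInt).getOrElse(n)
  }

  def main(args: Array[String]): Unit = {
    println(resolve(Map.empty, Map.empty))                              // 2
    println(resolve(Map("SPARK_EXECUTOR_INSTANCES" -> "4"), Map.empty)) // 4
    println(resolve(Map("SPARK_EXECUTOR_INSTANCES" -> "4"),
                    Map("spark.executor.instances" -> "8")))            // 8
  }
}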
diff --git a/yarn/pom.xml b/yarn/pom.xml
index ef7066ef1fdfc..efb473aa1b261 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -28,6 +28,9 @@
  <artifactId>yarn-parent_2.10</artifactId>
  <packaging>pom</packaging>
  <name>Spark Project YARN Parent POM</name>
+  <properties>
+    <sbt.project.name>yarn</sbt.project.name>
+  </properties>
diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml
index 0931beb505508..ceaf9f9d71001 100644
--- a/yarn/stable/pom.xml
+++ b/yarn/stable/pom.xml
@@ -23,6 +23,9 @@
    <version>1.1.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>
+  <properties>
+    <sbt.project.name>yarn-stable</sbt.project.name>
+  </properties>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-yarn_2.10</artifactId>
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index ee1e9c9c23d22..1a24ec759b546 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -164,6 +164,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
private def startUserClass(): Thread = {
logInfo("Starting the user JAR in a separate Thread")
+ System.setProperty("spark.executor.instances", args.numExecutors.toString)
val mainMethod = Class.forName(
args.userClass,
false,