", actorSystem, master, serializer, 1200, conf,
- securityMgr, mapOutputTracker, shuffleManager)
-
- val worker = spy(new BlockManagerWorker(store))
- val connManagerId = mock(classOf[ConnectionManagerId])
-
- // setup request block messages
- val reqBlId1 = ShuffleBlockId(0,0,0)
- val reqBlId2 = ShuffleBlockId(0,1,0)
- val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
- val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
- val reqBlockMessages = new BlockMessageArray(
- Seq(reqBlockMessage1, reqBlockMessage2))
-
- val tmpBufferMessage = reqBlockMessages.toBufferMessage
- val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
- val arrayBuffer = new ArrayBuffer[ByteBuffer]
- tmpBufferMessage.buffers.foreach{ b =>
- buffer.put(b)
- }
- buffer.flip()
- arrayBuffer += buffer
- val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
-
- // setup ack block messages
- val buf1 = ByteBuffer.allocate(4)
- val buf2 = ByteBuffer.allocate(4)
- buf1.putInt(1)
- buf1.flip()
- buf2.putInt(1)
- buf2.flip()
- val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
- val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
-
- val answer = new Answer[Option[BlockMessage]] {
- override def answer(invocation: InvocationOnMock)
- :Option[BlockMessage]= {
- if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(
- reqBlockMessage1)) {
- return Some(ackBlockMessage1)
- } else {
- return Some(ackBlockMessage2)
- }
- }
- }
-
- doAnswer(answer).when(worker).processBlockMessage(any())
-
- val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
- assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should be defined")
- assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " +
- "was executed successfully, ackMessage should not have error")
- }
-
test("reserve/release unroll memory") {
store = makeBlockManager(12000)
val memoryStore = store.memoryStore
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 26082ded8ca7a..e4522e00a622d 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.shuffle.hash.HashShuffleManager
import scala.collection.mutable
@@ -52,7 +53,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
rootDir1 = Files.createTempDir()
rootDir1.deleteOnExit()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
- println("Created root dirs: " + rootDirs)
}
override def afterAll() {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
new file mode 100644
index 0000000000000..809bd70929656
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.TaskContext
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+
+import org.mockito.Mockito._
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.scalatest.FunSuite
+
+
+class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+
+ test("handle local read failures in BlockManager") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+ val answer = new Answer[Option[Iterator[Any]]] {
+ override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
+ throw new Exception
+ }
+ }
+
+ // 3rd block is going to fail
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call
+ // getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ // the 2nd element of the tuple returned by iterator.next should be defined when
+ // fetching successfully
+ assert(iterator.next()._2.isDefined,
+ "1st element should be defined but is not actually defined")
+ verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "2nd element should be defined but is not actually defined")
+ verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ // 3rd fetch should be failed
+ intercept[Exception] {
+ iterator.next()
+ }
+ verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle local read successes") {
+ val transfer = mock(classOf[BlockTransferService])
+ val blockManager = mock(classOf[BlockManager])
+ doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
+
+ val blIds = Array[BlockId](
+ ShuffleBlockId(0,0,0),
+ ShuffleBlockId(0,1,0),
+ ShuffleBlockId(0,2,0),
+ ShuffleBlockId(0,3,0),
+ ShuffleBlockId(0,4,0))
+
+ val optItr = mock(classOf[Option[Iterator[Any]]])
+
+ // All blocks should be fetched successfully
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+ doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+
+ val bmId = BlockManagerId("test-client", "test-client", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
+ verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
+
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 1st element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 2nd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 3rd element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 4th element is not actually defined")
+ assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
+ assert(iterator.next()._2.isDefined,
+ "All elements should be defined but 5th element is not actually defined")
+
+ verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ }
+
+ test("handle remote fetch failures in BlockTransferService") {
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ listener.onBlockFetchFailure(new Exception("blah"))
+ }
+ })
+
+ val blockManager = mock(classOf[BlockManager])
+
+ when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
+
+ val blId1 = ShuffleBlockId(0, 0, 0)
+ val blId2 = ShuffleBlockId(0, 1, 0)
+ val bmId = BlockManagerId("test-server", "test-server", 1)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, Seq((blId1, 1L), (blId2, 1L))))
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ new TaskContext(0, 0, 0),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ null,
+ 48 * 1024 * 1024)
+
+ iterator.foreach { case (_, iterOption) =>
+ assert(!iterOption.isDefined)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index ac3931e3d0a73..511d76c9144cc 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -42,6 +42,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
conf.set("spark.serializer.objectStreamReset", "1")
conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
conf.set("spark.shuffle.spill.compress", codec.isDefined.toString)
+ conf.set("spark.shuffle.compress", codec.isDefined.toString)
codec.foreach { c => conf.set("spark.io.compression.codec", c) }
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
diff --git a/dev/check-license b/dev/check-license
index 625ec161bc571..9ff0929e9a5e8 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -23,18 +23,18 @@ acquire_rat_jar () {
URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar"
- JAR=$rat_jar
+ JAR="$rat_jar"
if [[ ! -f "$rat_jar" ]]; then
# Download rat launch jar if it hasn't been downloaded yet
if [ ! -f "$JAR" ]; then
# Download
printf "Attempting to fetch rat\n"
- JAR_DL=${JAR}.part
+ JAR_DL="${JAR}.part"
if hash curl 2>/dev/null; then
- (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR"
elif hash wget 2>/dev/null; then
- (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR"
+ (wget --quiet "${URL1}" -O "$JAR_DL" || wget --quiet "${URL2}" -O "$JAR_DL") && mv "$JAR_DL" "$JAR"
else
printf "You do not have curl or wget installed, please install rat manually.\n"
exit -1
@@ -50,7 +50,7 @@ acquire_rat_jar () {
}
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
if test -x "$JAVA_HOME/bin/java"; then
@@ -60,17 +60,17 @@ else
fi
export RAT_VERSION=0.10
-export rat_jar=$FWDIR/lib/apache-rat-${RAT_VERSION}.jar
-mkdir -p $FWDIR/lib
+export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar
+mkdir -p "$FWDIR"/lib
[[ -f "$rat_jar" ]] || acquire_rat_jar || {
echo "Download failed. Obtain the rat jar manually and place it at $rat_jar"
exit 1
}
-$java_cmd -jar $rat_jar -E $FWDIR/.rat-excludes -d $FWDIR > rat-results.txt
+$java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt
-ERRORS=$(cat rat-results.txt | grep -e "??")
+ERRORS="$(cat rat-results.txt | grep -e "??")"
if test ! -z "$ERRORS"; then
echo "Could not find Apache license headers in the following files:"
diff --git a/dev/lint-python b/dev/lint-python
index a1e890faa8fa6..772f856154ae0 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -18,10 +18,10 @@
#
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
-SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
+SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt"
-cd $SPARK_ROOT_DIR
+cd "$SPARK_ROOT_DIR"
# Get pep8 at runtime so that we don't rely on it being installed on the build server.
#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162
@@ -30,6 +30,7 @@ cd $SPARK_ROOT_DIR
#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?))
PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py"
PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py"
+PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/"
curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH"
curl_status=$?
@@ -44,7 +45,7 @@ fi
#+ first, but we do so so that the check status can
#+ be output before the report, like with the
#+ scalastyle and RAT checks.
-python $PEP8_SCRIPT_PATH ./python/pyspark > "$PEP8_REPORT_PATH"
+python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH"
pep8_status=${PIPESTATUS[0]} #$?
if [ $pep8_status -ne 0 ]; then
@@ -54,7 +55,7 @@ else
echo "PEP 8 checks passed."
fi
-rm -f "$PEP8_REPORT_PATH"
+rm "$PEP8_REPORT_PATH"
rm "$PEP8_SCRIPT_PATH"
exit $pep8_status
diff --git a/dev/mima b/dev/mima
index 09e4482af5f3d..f9b9b03538f15 100755
--- a/dev/mima
+++ b/dev/mima
@@ -21,12 +21,12 @@ set -o pipefail
set -e
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
echo -e "q\n" | sbt/sbt oldDeps/update
-export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`
+export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`"
echo "SPARK_CLASSPATH=$SPARK_CLASSPATH"
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
diff --git a/dev/run-tests b/dev/run-tests
index 90a8ce16f0f06..79401213a7fa2 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -18,7 +18,7 @@
#
# Go to the Spark project root directory
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
@@ -93,7 +93,7 @@ echo "========================================================================="
# echo "q" is needed because sbt on encountering a build file with failure
# (either resolution or compilation) prompts the user for input either q, r,
# etc to quit or retry. This echo is there to make it not block.
-BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver "
+BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive "
echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \
grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
diff --git a/dev/scalastyle b/dev/scalastyle
index eb9b467965636..efb5f291ea3b7 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -19,7 +19,7 @@
echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
# Check style with YARN alpha built too
-echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
+echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
# Check style with YARN built too
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \
diff --git a/docs/configuration.md b/docs/configuration.md
index 65a422caabb7e..36178efb97103 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -293,12 +293,11 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.shuffle.manager</code></td>
-  <td>HASH</td>
+  <td>sort</td>
<td>
-    Implementation to use for shuffling data. A hash-based shuffle manager is the default, but
-    starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more
-    memory-efficient in environments with small executors, such as YARN. To use that, change
-    this value to <code>SORT</code>.
+    Implementation to use for shuffling data. There are two implementations available:
+    <code>sort</code> and <code>hash</code>. Sort-based shuffle is more memory-efficient and is
+    the default option starting in 1.2.
</td>
diff --git a/docs/img/streaming-arch.png b/docs/img/streaming-arch.png
index bc57b460fdf8b..ac35f1d34cf3d 100644
Binary files a/docs/img/streaming-arch.png and b/docs/img/streaming-arch.png differ
diff --git a/docs/img/streaming-figures.pptx b/docs/img/streaming-figures.pptx
index 1b18c2ee0ea3e..d1cc25e379f46 100644
Binary files a/docs/img/streaming-figures.pptx and b/docs/img/streaming-figures.pptx differ
diff --git a/docs/img/streaming-kinesis-arch.png b/docs/img/streaming-kinesis-arch.png
new file mode 100644
index 0000000000000..bea5fa88df985
Binary files /dev/null and b/docs/img/streaming-kinesis-arch.png differ
diff --git a/docs/index.md b/docs/index.md
index 4ac0982ae54f1..7fe6b43d32af7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -103,6 +103,8 @@ options for deployment:
* [Security](security.html): Spark security support
* [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware
* [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions
+* Integration with other storage systems:
+ * [OpenStack Swift](storage-openstack-swift.html)
* [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system
* [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark)
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 1166d9cd150c4..12a6afbeea829 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -80,7 +80,7 @@ The ordered splits create "bins" and the maximum number of such
bins can be specified using the `maxBins` parameter.
Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
-since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
+since the default `maxBins` value is 32). The tree algorithm automatically reduces the number of
bins if the condition is not satisfied.
**Categorical features**
@@ -117,7 +117,7 @@ all nodes at each level of the tree. This could lead to high memory requirements
of the tree, potentially leading to memory overflow errors. To alleviate this problem, a `maxMemoryInMB`
training parameter specifies the maximum amount of memory at the workers (twice as much at the
master) to be allocated to the histogram computation. The default value is conservatively chosen to
-be 128 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
+be 256 MB to allow the decision algorithm to work in most scenarios. Once the memory requirements
for a level-wise computation cross the `maxMemoryInMB` threshold, the node training tasks at each
subsequent level are split into smaller tasks.
@@ -167,7 +167,7 @@ val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
val model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity,
maxDepth, maxBins)
@@ -213,7 +213,7 @@ Integer numClasses = 2;
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
@@ -250,7 +250,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
- impurity='gini', maxDepth=5, maxBins=100)
+ impurity='gini', maxDepth=5, maxBins=32)
# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
@@ -293,7 +293,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"
val maxDepth = 5
-val maxBins = 100
+val maxBins = 32
val model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo, impurity,
maxDepth, maxBins)
@@ -338,7 +338,7 @@ JavaSparkContext sc = new JavaSparkContext(sparkConf);
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "variance";
Integer maxDepth = 5;
-Integer maxBins = 100;
+Integer maxBins = 32;
// Train a DecisionTree model.
final DecisionTreeModel model = DecisionTree.trainRegressor(data,
@@ -380,7 +380,7 @@ data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt').cache()
# Train a DecisionTree model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
- impurity='variance', maxDepth=5, maxBins=100)
+ impurity='variance', maxDepth=5, maxBins=32)
# Evaluate model on training instances and compute training error
predictions = model.predict(data.map(lambda x: x.features))
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 943f06b114cb9..d8b22f3663d08 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -125,6 +125,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
the environment of the executor launcher.
+<tr>
+  <td><code>spark.yarn.containerLauncherMaxThreads</code></td>
+  <td>25</td>
+  <td>
+    The maximum number of threads to use in the application master for launching executor containers.
+  </td>
+</tr>
# Launching Spark on YARN
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 8f7fb5431cfb6..d83efa4bab324 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -13,10 +13,10 @@ title: Spark SQL Programming Guide
Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using
Spark. At the core of this component is a new type of RDD,
-[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed
-[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects along with
+[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of
+[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with
a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table
-in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io)
+in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`.
@@ -26,10 +26,10 @@ All of the examples on this page use sample data included in the Spark distribut
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using
Spark. At the core of this component is a new type of RDD,
-[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed
-[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects along with
+[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of
+[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with
a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table
-in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io)
+in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -37,10 +37,10 @@ file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using
Spark. At the core of this component is a new type of RDD,
-[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed
-[Row](api/python/pyspark.sql.Row-class.html) objects along with
+[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of
+[Row](api/python/pyspark.sql.Row-class.html) objects, along with
a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table
-in a traditional relational database. A SchemaRDD can be created from an existing RDD, [Parquet](http://parquet.io)
+in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
@@ -68,6 +68,16 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
{% endhighlight %}
+In addition to the basic SQLContext, you can also create a HiveContext, which provides a
+superset of the functionality provided by the basic SQLContext. Additional features include
+the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the
+ability to read data from Hive tables. To use a HiveContext, you do not need to have an
+existing Hive setup, and all of the data sources available to a SQLContext are still available.
+HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default
+Spark build. If these dependencies are not a problem for your application then using HiveContext
+is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to
+feature parity with a HiveContext.
+
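+As a brief illustration (a minimal sketch that simply reuses the HiveContext constructor shown in the Hive Tables section below), creating a HiveContext requires nothing more than an existing SparkContext:
+
+{% highlight scala %}
+// sc is an existing SparkContext.
+val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)
+
+// hiveContext can be used everywhere a SQLContext is used, with HiveQL additionally available.
+{% endhighlight %}
+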
@@ -81,6 +91,16 @@ JavaSparkContext sc = ...; // An existing JavaSparkContext.
JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
{% endhighlight %}
+In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict
+superset of the functionality provided by the basic SQLContext. Additional features include
+the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the
+ability to read data from Hive tables. To use a HiveContext, you do not need to have an
+existing Hive setup, and all of the data sources available to a SQLContext are still available.
+HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default
+Spark build. If these dependencies are not a problem for your application then using HiveContext
+is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to
+feature parity with a HiveContext.
+
@@ -94,36 +114,52 @@ from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
{% endhighlight %}
+In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict
+superset of the functionality provided by the basic SQLContext. Additional features include
+the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the
+ability to read data from Hive tables. To use a HiveContext, you do not need to have an
+existing Hive setup, and all of the data sources available to a SQLContext are still available.
+HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default
+Spark build. If these dependencies are not a problem for your application then using HiveContext
+is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to
+feature parity with a HiveContext.
+
+The specific variant of SQL that is used to parse queries can also be selected using the
+`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on
+a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect
+available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the
+default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete,
+this is recommended for most use cases.
+
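+For example, a minimal sketch (assuming the `setConf` method and SQL `SET` command described above) of switching the dialect on a HiveContext:
+
+{% highlight scala %}
+// sqlContext is an existing HiveContext.
+sqlContext.setConf("spark.sql.dialect", "sql")    // use the simple SQL parser
+sqlContext.setConf("spark.sql.dialect", "hiveql") // switch back to the HiveQL parser
+
+// The same setting can be changed from SQL.
+sqlContext.sql("SET spark.sql.dialect=hiveql")
+{% endhighlight %}
+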
# Data Sources
-
-
Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface.
-Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources.
-
-
-
-Spark SQL supports operating on a variety of data sources through the `JavaSchemaRDD` interface.
-Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources.
-
-
-
-Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface.
-Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources.
-
-
+A SchemaRDD can be operated on as a normal RDD and can also be registered as a temporary table.
+Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section
+describes the various methods for loading data into a SchemaRDD.
## RDDs
+Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first
+method uses reflection to infer the schema of an RDD that contains specific types of objects. This
+reflection based approach leads to more concise code and works well when you already know the schema
+while writing your Spark application.
+
+The second method for creating SchemaRDDs is through a programmatic interface that allows you to
+construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows
+you to construct SchemaRDDs when the columns and their types are not known until runtime.
+
+### Inferring the Schema Using Reflection
-One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class
+The Scala interface for Spark SQL supports automatically converting an RDD containing case classes
+to a SchemaRDD. The case class
defines the schema of the table. The names of the arguments to the case class are read using
reflection and become the names of the columns. Case classes can also be nested or contain complex
types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be
@@ -156,8 +192,9 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
-One type of table that is supported by Spark SQL is an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly). The BeanInfo
-defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain
+Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly)
+into a SchemaRDD. The BeanInfo, obtained using reflection, defines the schema of the table.
+Currently, Spark SQL does not support JavaBeans that contain
nested or complex types such as Lists or Arrays. You can create a JavaBean by creating a
class that implements Serializable and has getters and setters for all of its fields.
@@ -192,7 +229,7 @@ for the JavaBean.
{% highlight java %}
// sc is an existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc)
+JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
// Load a text file and convert each line to a JavaBean.
JavaRDD<Person> people = sc.textFile("examples/src/main/resources/people.txt").map(
@@ -229,24 +266,24 @@ List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
-One type of table that is supported by Spark SQL is an RDD of dictionaries. The keys of the
-dictionary define the columns names of the table, and the types are inferred by looking at the first
-row. Any RDD of dictionaries can converted to a SchemaRDD and then registered as a table. Tables
-can be used in subsequent SQL statements.
+Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of
+key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table,
+and the types are inferred by looking at the first row. Since we currently only look at the first
+row, it is important that there is no missing data in the first row of the RDD. In future versions we
+plan to more completely infer the schema by looking at more data, similar to the inference that is
+performed on JSON files.
{% highlight python %}
# sc is an existing SparkContext.
-from pyspark.sql import SQLContext
+from pyspark.sql import SQLContext, Row
sqlContext = SQLContext(sc)
# Load a text file and convert each line to a dictionary.
lines = sc.textFile("examples/src/main/resources/people.txt")
parts = lines.map(lambda l: l.split(","))
-people = parts.map(lambda p: {"name": p[0], "age": int(p[1])})
+people = parts.map(lambda p: Row(name=p[0], age=int(p[1])))
# Infer the schema, and register the SchemaRDD as a table.
-# In future versions of PySpark we would like to add support for registering RDDs with other
-# datatypes as tables
schemaPeople = sqlContext.inferSchema(people)
schemaPeople.registerTempTable("people")
@@ -263,15 +300,191 @@ for teenName in teenNames.collect():
-**Note that Spark SQL currently uses a very basic SQL parser.**
-Users that want a more complete dialect of SQL should look at the HiveQL support provided by
-`HiveContext`.
+### Programmatically Specifying the Schema
+
+
+
+
+
+When case classes cannot be defined ahead of time (for example,
+the structure of records is encoded in a string, or a text dataset will be parsed
+and fields will be projected differently for different users),
+a `SchemaRDD` can be created programmatically with three steps.
+
+1. Create an RDD of `Row`s from the original RDD;
+2. Create the schema represented by a `StructType` matching the structure of
+`Row`s in the RDD created in Step 1.
+3. Apply the schema to the RDD of `Row`s via the `applySchema` method provided
+by `SQLContext`.
+
+For example:
+{% highlight scala %}
+// sc is an existing SparkContext.
+val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+
+// Create an RDD
+val people = sc.textFile("examples/src/main/resources/people.txt")
+
+// The schema is encoded in a string
+val schemaString = "name age"
+
+// Import Spark SQL data types and Row.
+import org.apache.spark.sql._
+
+// Generate the schema based on the string of schema
+val schema =
+ StructType(
+ schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))
+
+// Convert records of the RDD (people) to Rows.
+val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
+
+// Apply the schema to the RDD.
+val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema)
+
+// Register the SchemaRDD as a table.
+peopleSchemaRDD.registerTempTable("people")
+
+// SQL statements can be run by using the sql methods provided by sqlContext.
+val results = sqlContext.sql("SELECT name FROM people")
+
+// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The columns of a row in the result can be accessed by ordinal.
+results.map(t => "Name: " + t(0)).collect().foreach(println)
+{% endhighlight %}
+
+
+
+
+
+
+When JavaBean classes cannot be defined ahead of time (for example,
+the structure of records is encoded in a string, or a text dataset will be parsed and
+fields will be projected differently for different users),
+a `SchemaRDD` can be created programmatically with three steps.
+
+1. Create an RDD of `Row`s from the original RDD;
+2. Create the schema represented by a `StructType` matching the structure of
+`Row`s in the RDD created in Step 1.
+3. Apply the schema to the RDD of `Row`s via the `applySchema` method provided
+by `JavaSQLContext`.
+
+For example:
+{% highlight java %}
+// Import factory methods provided by DataType.
+import org.apache.spark.sql.api.java.DataType;
+// Import StructType and StructField
+import org.apache.spark.sql.api.java.StructType;
+import org.apache.spark.sql.api.java.StructField;
+// Import Row.
+import org.apache.spark.sql.api.java.Row;
+
+// sc is an existing JavaSparkContext.
+JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
+
+// Load a text file.
+JavaRDD<String> people = sc.textFile("examples/src/main/resources/people.txt");
+
+// The schema is encoded in a string
+String schemaString = "name age";
+
+// Generate the schema based on the string of schema
+List<StructField> fields = new ArrayList<StructField>();
+for (String fieldName: schemaString.split(" ")) {
+ fields.add(DataType.createStructField(fieldName, DataType.StringType, true));
+}
+StructType schema = DataType.createStructType(fields);
+
+// Convert records of the RDD (people) to Rows.
+JavaRDD<Row> rowRDD = people.map(
+ new Function<String, Row>() {
+ public Row call(String record) throws Exception {
+ String[] fields = record.split(",");
+ return Row.create(fields[0], fields[1].trim());
+ }
+ });
+
+// Apply the schema to the RDD.
+JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema);
+
+// Register the SchemaRDD as a table.
+peopleSchemaRDD.registerTempTable("people");
+
+// SQL can be run over RDDs that have been registered as tables.
+JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people");
+
+// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The columns of a row in the result can be accessed by ordinal.
+List<String> names = results.map(new Function<Row, String>() {
+ public String call(Row row) {
+ return "Name: " + row.getString(0);
+ }
+}).collect();
+
+{% endhighlight %}
+
+
+
+
+
+When a dictionary of kwargs cannot be defined ahead of time (for example,
+the structure of records is encoded in a string, or a text dataset will be parsed and
+fields will be projected differently for different users),
+a `SchemaRDD` can be created programmatically with three steps.
+
+1. Create an RDD of tuples or lists from the original RDD;
+2. Create the schema represented by a `StructType` matching the structure of
+tuples or lists in the RDD created in Step 1.
+3. Apply the schema to the RDD via the `applySchema` method provided by `SQLContext`.
+
+For example:
+{% highlight python %}
+# Import SQLContext and data types
+from pyspark.sql import *
+
+# sc is an existing SparkContext.
+sqlContext = SQLContext(sc)
+
+# Load a text file and convert each line to a tuple.
+lines = sc.textFile("examples/src/main/resources/people.txt")
+parts = lines.map(lambda l: l.split(","))
+people = parts.map(lambda p: (p[0], p[1].strip()))
+
+# The schema is encoded in a string.
+schemaString = "name age"
+
+fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()]
+schema = StructType(fields)
+
+# Apply the schema to the RDD.
+schemaPeople = sqlContext.applySchema(people, schema)
+
+# Register the SchemaRDD as a table.
+schemaPeople.registerTempTable("people")
+
+# SQL can be run over SchemaRDDs that have been registered as a table.
+results = sqlContext.sql("SELECT name FROM people")
+
+# The results of SQL queries are RDDs and support all the normal RDD operations.
+names = results.map(lambda p: "Name: " + p.name)
+for name in names.collect():
+ print name
+{% endhighlight %}
+
+
+
+
+
## Parquet Files
[Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems.
Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema
-of the original data. Using the data from the above example:
+of the original data.
+
+### Loading Data Programmatically
+
+Using the data from the above example:
@@ -349,7 +562,40 @@ for teenName in teenNames.collect():
-
+
+
+### Configuration
+
+Configuration of Parquet can be done using the `setConf` method on SQLContext or by running
+`SET key=value` commands using SQL.
+
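+For example, a minimal sketch using the `setConf` form described above:
+
+{% highlight scala %}
+// sqlContext is an existing SQLContext.
+sqlContext.setConf("spark.sql.parquet.binaryAsString", "true")
+{% endhighlight %}
+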
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.parquet.binaryAsString</code></td>
+  <td>false</td>
+  <td>
+    Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do
+    not differentiate between binary data and strings when writing out the Parquet schema. This
+    flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.parquet.cacheMetadata</code></td>
+  <td>false</td>
+  <td>
+    Turns on caching of Parquet schema metadata. Can speed up querying of static data.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.parquet.compression.codec</code></td>
+  <td>snappy</td>
+  <td>
+    Sets the compression codec to use when writing Parquet files. Acceptable values include:
+    uncompressed, snappy, gzip, lzo.
+  </td>
+</tr>
+</table>
+
## JSON Datasets
@@ -493,13 +739,13 @@ directory.
{% highlight scala %}
// sc is an existing SparkContext.
-val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)
+val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc)
-hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
// Queries are expressed in HiveQL
-hiveContext.sql("FROM src SELECT key, value").collect().foreach(println)
+sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
{% endhighlight %}
@@ -513,13 +759,13 @@ expressed in HiveQL.
{% highlight java %}
// sc is an existing JavaSparkContext.
-JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc);
+JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc);
-hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)");
-hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src");
+sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)");
+sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src");
// Queries are expressed in HiveQL.
-Row[] results = hiveContext.sql("FROM src SELECT key, value").collect();
+Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
{% endhighlight %}
@@ -535,49 +781,101 @@ expressed in HiveQL.
{% highlight python %}
# sc is an existing SparkContext.
from pyspark.sql import HiveContext
-hiveContext = HiveContext(sc)
+sqlContext = HiveContext(sc)
-hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
# Queries can be expressed in HiveQL.
-results = hiveContext.sql("FROM src SELECT key, value").collect()
+results = sqlContext.sql("FROM src SELECT key, value").collect()
{% endhighlight %}
-# Writing Language-Integrated Relational Queries
+# Performance Tuning
-**Language-Integrated queries are currently only supported in Scala.**
+For some workloads it is possible to improve performance by either caching data in memory, or by
+turning on some experimental options.
-Spark SQL also supports a domain specific language for writing queries. Once again,
-using the data from the above examples:
+## Caching Data In Memory
-{% highlight scala %}
-// sc is an existing SparkContext.
-val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-// Importing the SQL context gives access to all the public SQL functions and implicit conversions.
-import sqlContext._
-val people: RDD[Person] = ... // An RDD of case class objects, from the first example.
-
-// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19'
-val teenagers = people.where('age >= 10).where('age <= 19).select('name)
-teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
-{% endhighlight %}
-
-The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers
-prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are
-evaluated by the SQL execution engine. A full list of the functions supported can be found in the
-[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
+Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`.
+Then Spark SQL will scan only required columns and will automatically tune compression to minimize
+memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory.
-
+Note that if you call `cache` rather than `cacheTable`, tables will _not_ be cached using
+the in-memory columnar format, and therefore `cacheTable` is strongly recommended for this use case.
+
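+For example, a minimal sketch using the `cacheTable` and `uncacheTable` calls described above (assuming a table named "people" has already been registered):
+
+{% highlight scala %}
+// sqlContext is an existing SQLContext.
+sqlContext.cacheTable("people")          // subsequent scans read from the in-memory columnar cache
+sqlContext.sql("SELECT COUNT(*) FROM people").collect()
+sqlContext.uncacheTable("people")        // remove the table from memory when it is no longer needed
+{% endhighlight %}
+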
+Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running
+`SET key=value` commands using SQL.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.inMemoryColumnarStorage.compressed</code></td>
+  <td>false</td>
+  <td>
+    When set to true Spark SQL will automatically select a compression codec for each column based
+    on statistics of the data.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.inMemoryColumnarStorage.batchSize</code></td>
+  <td>1000</td>
+  <td>
+    Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
+    and compression, but risk OOMs when caching data.
+  </td>
+</tr>
+</table>
+
+
+## Other Configuration Options
+
+The following options can also be used to tune the performance of query execution. It is possible
+that these options will be deprecated in future releases as more optimizations are performed automatically.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.autoBroadcastJoinThreshold</code></td>
+  <td>10000</td>
+  <td>
+    Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
+    performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
+    statistics are only supported for Hive Metastore tables where the command
+    <code>ANALYZE TABLE &lt;tableName&gt; COMPUTE STATISTICS noscan</code> has been run.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.codegen</code></td>
+  <td>false</td>
+  <td>
+    When true, code will be dynamically generated at runtime for expression evaluation in a specific
+    query. For some queries with complicated expressions this option can lead to significant speed-ups.
+    However, for simple queries this can actually slow down query execution.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.shuffle.partitions</code></td>
+  <td>200</td>
+  <td>
+    Configures the number of partitions to use when shuffling data for joins or aggregations.
+  </td>
+</tr>
+</table>
+
+
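+For example, a minimal sketch (assuming the `setConf` method and SQL `SET` command described earlier) of adjusting these options:
+
+{% highlight scala %}
+// sqlContext is an existing SQLContext.
+sqlContext.setConf("spark.sql.shuffle.partitions", "50")         // use fewer post-shuffle partitions
+sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1") // disable broadcast joins entirely
+
+// The same options can be set from SQL.
+sqlContext.sql("SET spark.sql.codegen=true")
+{% endhighlight %}
+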
+# Other SQL Interfaces
+
+Spark SQL also supports interfaces for running SQL queries directly without the need to write any
+code.
## Running the Thrift JDBC server
The Thrift JDBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2)
-in Hive 0.12. You can test the JDBC server with the beeline script comes with either Spark or Hive 0.12.
+in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12.
To start the JDBC server, run the following in the Spark directory:
@@ -600,19 +898,36 @@ your machine and a blank password. For secure mode, please follow the instructio
Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
-You may also use the beeline script comes with Hive.
+You may also use the beeline script that comes with Hive.
+## Running the Spark SQL CLI
+
+The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute
+queries input from the command line. Note that the Spark SQL CLI cannot talk to the Thrift JDBC server.
+
+To start the Spark SQL CLI, run the following in the Spark directory:
+
+ ./bin/spark-sql
+
+Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+You may run `./bin/spark-sql --help` for a complete list of all available
+options.
+
+# Compatibility with Other Systems
+
+## Migration Guide for Shark Users
+
+### Scheduling
+
To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session,
users can set the `spark.sql.thriftserver.scheduler.pool` variable:
SET spark.sql.thriftserver.scheduler.pool=accounting;
-### Migration Guide for Shark Users
-
-#### Reducer number
+### Reducer number
In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark
-SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value
+SQL deprecates this property in favor of `spark.sql.shuffle.partitions`, whose default value
is 200. Users may customize this property via `SET`:
SET spark.sql.shuffle.partitions=10;
@@ -625,7 +940,7 @@ You may also put this property in `hive-site.xml` to override the default value.
For now, the `mapred.reduce.tasks` property is still recognized, and is converted to
`spark.sql.shuffle.partitions` automatically.
-#### Caching
+### Caching
The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no
longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to
@@ -634,9 +949,9 @@ let user control table caching explicitly:
CACHE TABLE logs_last_month;
UNCACHE TABLE logs_last_month;
-**NOTE:** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary",
-but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be
-cached, you may simply count the table immediately after executing `CACHE TABLE`:
+**NOTE:** `CACHE TABLE tbl` is lazy, similar to `.cache` on an RDD. This command only marks `tbl` so that
+its partitions are cached when they are computed, but doesn't actually cache it until a query that touches `tbl` is executed.
+To force the table to be cached, you may simply count the table immediately after executing `CACHE TABLE`:
CACHE TABLE logs_last_month;
SELECT COUNT(1) FROM logs_last_month;
@@ -647,15 +962,18 @@ Several caching related features are not supported yet:
* RDD reloading
* In-memory cache write through policy
-### Compatibility with Apache Hive
+## Compatibility with Apache Hive
+
+Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark
+SQL is based on Hive 0.12.0.
#### Deploying in Existing Hive Warehouses
-Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
+The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive
installations. You do not need to modify your existing Hive Metastore or change the data placement
or partitioning of your tables.
-#### Supported Hive Features
+### Supported Hive Features
Spark SQL supports the vast majority of Hive features, such as:
@@ -705,13 +1023,14 @@ Spark SQL supports the vast majority of Hive features, such as:
* `MAP<>`
* `STRUCT<>`
-#### Unsupported Hive Functionality
+### Unsupported Hive Functionality
Below is a list of Hive features that we don't support yet. Most of these features are rarely used
in Hive deployments.
**Major Hive Features**
+* Spark SQL does not currently support inserting into tables using dynamic partitioning.
* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
doesn't support buckets yet.
@@ -721,11 +1040,11 @@ in Hive deployments.
have the same input format.
* Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions
(e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple.
-* `UNIONTYPE`
+* `UNION` type and `DATE` type
* Unique join
* Single query multi insert
* Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at
- the moment.
+ the moment and only supports populating the sizeInBytes field of the hive metastore.
**Hive Input/Output Formats**
@@ -735,7 +1054,7 @@ in Hive deployments.
**Hive Optimizations**
A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are
-not necessary due to Spark SQL's in-memory computational model. Others are slotted for future
+less important due to Spark SQL's in-memory computational model. Others are slotted for future
releases of Spark SQL.
* Block level bitmap indexes and virtual columns (used to build indexes)
@@ -743,8 +1062,7 @@ releases of Spark SQL.
Hive automatically converts the join into a map join. We are adding this auto conversion in the
next release.
* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you
- need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". We are going to add auto-setting of parallelism in the
- next release.
+ need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`".
* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still
launches tasks to compute the result.
* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
@@ -753,25 +1071,471 @@ releases of Spark SQL.
Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS
metadata. Spark SQL does not support that.
-## Running the Spark SQL CLI
+# Writing Language-Integrated Relational Queries
-The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute
-queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server.
+**Language-Integrated queries are experimental and currently only supported in Scala.**
-To start the Spark SQL CLI, run the following in the Spark directory:
+Spark SQL also supports a domain specific language for writing queries. Once again,
+using the data from the above examples:
- ./bin/spark-sql
+{% highlight scala %}
+// sc is an existing SparkContext.
+val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+// Importing the SQL context gives access to all the public SQL functions and implicit conversions.
+import sqlContext._
+val people: RDD[Person] = ... // An RDD of case class objects, from the first example.
-Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
-You may run `./bin/spark-sql --help` for a complete list of all available
-options.
+// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19'
+val teenagers = people.where('age >= 10).where('age <= 19).select('name)
+teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
+{% endhighlight %}
+
+The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers
+prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are
+evaluated by the SQL execution engine. A full list of the functions supported can be found in the
+[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
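+
+As a further illustration (a small sketch, not part of the original guide, assuming the same `sqlContext`
+and `people` RDD as in the example above), symbol-based expressions can also be used for sorting and
+projection; `orderBy` and `asc` are assumed to be among the supported DSL operators:
+
+{% highlight scala %}
+// Roughly 'SELECT name FROM people WHERE age >= 10 AND age <= 19 ORDER BY age'
+val sortedTeenagers = people.where('age >= 10).where('age <= 19).orderBy('age.asc).select('name)
+sortedTeenagers.map(t => "Name: " + t(0)).collect().foreach(println)
+{% endhighlight %}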
-# Cached tables
+
-Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`.
-Then Spark SQL will scan only required columns and will automatically tune compression to minimize
-memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory.
+# Spark SQL DataType Reference
+
+* Numeric types
+ - `ByteType`: Represents 1-byte signed integer numbers.
+ The range of numbers is from `-128` to `127`.
+ - `ShortType`: Represents 2-byte signed integer numbers.
+ The range of numbers is from `-32768` to `32767`.
+ - `IntegerType`: Represents 4-byte signed integer numbers.
+ The range of numbers is from `-2147483648` to `2147483647`.
+ - `LongType`: Represents 8-byte signed integer numbers.
+ The range of numbers is from `-9223372036854775808` to `9223372036854775807`.
+ - `FloatType`: Represents 4-byte single-precision floating point numbers.
+ - `DoubleType`: Represents 8-byte double-precision floating point numbers.
+ - `DecimalType`: Represents arbitrary-precision signed decimal numbers.
+* String type
+ - `StringType`: Represents character string values.
+* Binary type
+ - `BinaryType`: Represents byte sequence values.
+* Boolean type
+ - `BooleanType`: Represents boolean values.
+* Datetime type
+ - `TimestampType`: Represents values comprising values of fields year, month, day,
+ hour, minute, and second.
+* Complex types
+ - `ArrayType(elementType, containsNull)`: Represents values comprising a sequence of
+ elements with the type of `elementType`. `containsNull` is used to indicate if
+ elements in an `ArrayType` value can have `null` values.
+ - `MapType(keyType, valueType, valueContainsNull)`:
+ Represents values comprising a set of key-value pairs. The data type of keys is
+ described by `keyType` and the data type of values is described by `valueType`.
+ For a `MapType` value, keys are not allowed to have `null` values. `valueContainsNull`
+ is used to indicate if values of a `MapType` value can have `null` values.
+ - `StructType(fields)`: Represents values with the structure described by
+ a sequence of `StructField`s (`fields`).
+ * `StructField(name, dataType, nullable)`: Represents a field in a `StructType`.
+ The name of a field is indicated by `name`. The data type of a field is indicated
+ by `dataType`. `nullable` is used to indicate if values of this field can have
+ `null` values.
+
+
+
+
+All data types of Spark SQL are located in the package `org.apache.spark.sql`.
+You can access them by doing
+{% highlight scala %}
+import org.apache.spark.sql._
+{% endhighlight %}
+
+
+
+| Data type | Value type in Scala | API to access or create a data type |
+| --------- | ------------------- | ------------------------------------ |
+| ByteType | Byte | ByteType |
+| ShortType | Short | ShortType |
+| IntegerType | Int | IntegerType |
+| LongType | Long | LongType |
+| FloatType | Float | FloatType |
+| DoubleType | Double | DoubleType |
+| DecimalType | scala.math.BigDecimal | DecimalType |
+| StringType | String | StringType |
+| BinaryType | Array[Byte] | BinaryType |
+| BooleanType | Boolean | BooleanType |
+| TimestampType | java.sql.Timestamp | TimestampType |
+| ArrayType | scala.collection.Seq | ArrayType(elementType, [containsNull]). Note: The default value of containsNull is false. |
+| MapType | scala.collection.Map | MapType(keyType, valueType, [valueContainsNull]). Note: The default value of valueContainsNull is true. |
+| StructType | org.apache.spark.sql.Row | StructType(fields). Note: fields is a Seq of StructFields. Also, two fields with the same name are not allowed. |
+| StructField | The value type in Scala of the data type of this field (for example, Int for a StructField with the data type IntegerType) | StructField(name, dataType, nullable) |
+
+
+
+
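+As an illustration (a minimal sketch, not part of the original reference tables; the field names are
+hypothetical), the complex types above can be composed to describe a nested record schema
+programmatically:
+
+{% highlight scala %}
+import org.apache.spark.sql._
+
+// A hypothetical schema: a required name, an optional age, and an optional list of scores.
+val schema = StructType(Seq(
+  StructField("name", StringType, nullable = false),
+  StructField("age", IntegerType, nullable = true),
+  StructField("scores", ArrayType(DoubleType, containsNull = false), nullable = true)))
+{% endhighlight %}
+
+A schema built this way can then be supplied wherever a `StructType` is expected, for example when
+applying a schema to an RDD of `Row` objects with `SQLContext.applySchema`.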
+
+
+
+All data types of Spark SQL are located in the package
+`org.apache.spark.sql.api.java`. To access or create a data type,
+please use factory methods provided in
+`org.apache.spark.sql.api.java.DataType`.
+
+
+
+| Data type | Value type in Java | API to access or create a data type |
+| --------- | ------------------ | ------------------------------------ |
+| ByteType | byte or Byte | DataType.ByteType |
+| ShortType | short or Short | DataType.ShortType |
+| IntegerType | int or Integer | DataType.IntegerType |
+| LongType | long or Long | DataType.LongType |
+| FloatType | float or Float | DataType.FloatType |
+| DoubleType | double or Double | DataType.DoubleType |
+| DecimalType | java.math.BigDecimal | DataType.DecimalType |
+| StringType | String | DataType.StringType |
+| BinaryType | byte[] | DataType.BinaryType |
+| BooleanType | boolean or Boolean | DataType.BooleanType |
+| TimestampType | java.sql.Timestamp | DataType.TimestampType |
+| ArrayType | java.util.List | DataType.createArrayType(elementType) or DataType.createArrayType(elementType, containsNull). Note: The value of containsNull will be false in the first form. |
+| MapType | java.util.Map | DataType.createMapType(keyType, valueType) or DataType.createMapType(keyType, valueType, valueContainsNull). Note: The value of valueContainsNull will be true in the first form. |
+| StructType | org.apache.spark.sql.api.java.Row | DataType.createStructType(fields). Note: fields is a List or an array of StructFields. Also, two fields with the same name are not allowed. |
+| StructField | The value type in Java of the data type of this field (for example, int for a StructField with the data type IntegerType) | DataType.createStructField(name, dataType, nullable) |
+
+
+
+
+
+
+
+All data types of Spark SQL are located in the package `pyspark.sql`.
+You can access them by doing
+{% highlight python %}
+from pyspark.sql import *
+{% endhighlight %}
+
+
+
+| Data type | Value type in Python | API to access or create a data type |
+| --------- | -------------------- | ------------------------------------ |
+| ByteType | int or long. Note: Numbers will be converted to 1-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -128 to 127. | ByteType() |
+| ShortType | int or long. Note: Numbers will be converted to 2-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -32768 to 32767. | ShortType() |
+| IntegerType | int or long | IntegerType() |
+| LongType | long. Note: Numbers will be converted to 8-byte signed integer numbers at runtime. Please make sure that numbers are within the range of -9223372036854775808 to 9223372036854775807. Otherwise, please convert data to decimal.Decimal and use DecimalType. | LongType() |
+| FloatType | float. Note: Numbers will be converted to 4-byte single-precision floating point numbers at runtime. | FloatType() |
+| DoubleType | float | DoubleType() |
+| DecimalType | decimal.Decimal | DecimalType() |
+| StringType | string | StringType() |
+| BinaryType | bytearray | BinaryType() |
+| BooleanType | bool | BooleanType() |
+| TimestampType | datetime.datetime | TimestampType() |
+| ArrayType | list, tuple, or array | ArrayType(elementType, [containsNull]). Note: The default value of containsNull is False. |
+| MapType | dict | MapType(keyType, valueType, [valueContainsNull]). Note: The default value of valueContainsNull is True. |
+| StructType | list or tuple | StructType(fields). Note: fields is a list of StructFields. Also, two fields with the same name are not allowed. |
+| StructField | The value type in Python of the data type of this field (for example, int for a StructField with the data type IntegerType) | StructField(name, dataType, nullable) |
+
+
+
+
+
+
-Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in
-in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to
-cache tables.
diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md
new file mode 100644
index 0000000000000..c39ef1ce59e1c
--- /dev/null
+++ b/docs/storage-openstack-swift.md
@@ -0,0 +1,152 @@
+---
+layout: global
+title: Accessing OpenStack Swift from Spark
+---
+
+Spark's support for Hadoop InputFormat allows it to process data in OpenStack Swift using the
+same URI formats as in Hadoop. You can specify a path in Swift as input through a
+URI of the form `swift://container.PROVIDER/path`. You will also need to set your
+Swift security credentials, through `core-site.xml` or via
+`SparkContext.hadoopConfiguration`.
+The current Swift driver requires Swift to use the Keystone authentication method.
+
+# Configuring Swift for Better Data Locality
+
+Although not mandatory, it is recommended to configure the proxy server of Swift with
+`list_endpoints` to have better data locality. More information is
+[available here](https://github.com/openstack/swift/blob/master/swift/common/middleware/list_endpoints.py).
+
+
+# Dependencies
+
+The Spark application should include the `hadoop-openstack` dependency.
+For example, for Maven support, add the following to the `pom.xml` file:
+
+{% highlight xml %}
+<dependencies>
+  ...
+  <dependency>
+    <groupId>org.apache.hadoop</groupId>
+    <artifactId>hadoop-openstack</artifactId>
+    <version>2.3.0</version>
+  </dependency>
+  ...
+</dependencies>
+{% endhighlight %}
+
+
+# Configuration Parameters
+
+Create `core-site.xml` and place it inside Spark's `conf` directory.
+There are two main categories of parameters that should be configured: declaration of the
+Swift driver and the parameters that are required by Keystone.
+
+Configuring Hadoop to use the Swift file system is achieved by setting the following property:
+
+| Property Name | Value |
+| ------------- | ----- |
+| `fs.swift.impl` | `org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem` |
+
+Additional parameters are required by Keystone (v2.0) and should be provided to the Swift driver. These
+parameters are used to perform authentication in Keystone to access Swift. The following table
+contains a list of the mandatory Keystone parameters. `PROVIDER` can be any name.
+
+
+| Property Name | Meaning | Required |
+| ------------- | ------- | -------- |
+| `fs.swift.service.PROVIDER.auth.url` | Keystone Authentication URL | Mandatory |
+| `fs.swift.service.PROVIDER.auth.endpoint.prefix` | Keystone endpoints prefix | Optional |
+| `fs.swift.service.PROVIDER.tenant` | Tenant | Mandatory |
+| `fs.swift.service.PROVIDER.username` | Username | Mandatory |
+| `fs.swift.service.PROVIDER.password` | Password | Mandatory |
+| `fs.swift.service.PROVIDER.http.port` | HTTP port | Mandatory |
+| `fs.swift.service.PROVIDER.region` | Keystone region | Mandatory |
+| `fs.swift.service.PROVIDER.public` | Indicates if all URLs are public | Mandatory |
+
+
+
+For example, assume `PROVIDER=SparkTest` and Keystone contains user `tester` with password `testing`
+defined for tenant `test`. Then `core-site.xml` should include:
+
+{% highlight xml %}
+<configuration>
+  <property>
+    <name>fs.swift.impl</name>
+    <value>org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.auth.url</name>
+    <value>http://127.0.0.1:5000/v2.0/tokens</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.auth.endpoint.prefix</name>
+    <value>endpoints</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.http.port</name>
+    <value>8080</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.region</name>
+    <value>RegionOne</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.public</name>
+    <value>true</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.tenant</name>
+    <value>test</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.username</name>
+    <value>tester</value>
+  </property>
+  <property>
+    <name>fs.swift.service.SparkTest.password</name>
+    <value>testing</value>
+  </property>
+</configuration>
+{% endhighlight %}
+
+Notice that
+`fs.swift.service.PROVIDER.tenant`,
+`fs.swift.service.PROVIDER.username`, and
+`fs.swift.service.PROVIDER.password` contain sensitive information, so keeping them in
+`core-site.xml` is not always a good approach.
+We suggest keeping those parameters in `core-site.xml` only for testing purposes, when running Spark
+via `spark-shell`.
+For job submissions they should be provided via `sparkContext.hadoopConfiguration`.
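+
+As an illustration (a minimal sketch, not part of the original guide, reusing the placeholder
+`SparkTest` credentials above; the container name and file path are hypothetical), the same
+parameters can be set programmatically on the Hadoop configuration before reading data from Swift:
+
+{% highlight scala %}
+import org.apache.spark.{SparkConf, SparkContext}
+
+val sc = new SparkContext(new SparkConf().setAppName("SwiftExample"))
+val hadoopConf = sc.hadoopConfiguration
+
+// Declare the Swift file system implementation.
+hadoopConf.set("fs.swift.impl", "org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem")
+
+// Keystone parameters for PROVIDER=SparkTest (placeholder credentials).
+hadoopConf.set("fs.swift.service.SparkTest.auth.url", "http://127.0.0.1:5000/v2.0/tokens")
+hadoopConf.set("fs.swift.service.SparkTest.tenant", "test")
+hadoopConf.set("fs.swift.service.SparkTest.username", "tester")
+hadoopConf.set("fs.swift.service.SparkTest.password", "testing")
+hadoopConf.set("fs.swift.service.SparkTest.http.port", "8080")
+hadoopConf.set("fs.swift.service.SparkTest.region", "RegionOne")
+hadoopConf.set("fs.swift.service.SparkTest.public", "true")
+
+// Read a file from a Swift container through the swift:// URI scheme (hypothetical path).
+val data = sc.textFile("swift://container.SparkTest/data.txt")
+println(data.count())
+{% endhighlight %}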
diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md
index 079d4c5550537..c6090d9ec30c7 100644
--- a/docs/streaming-kinesis-integration.md
+++ b/docs/streaming-kinesis-integration.md
@@ -3,8 +3,8 @@ layout: global
title: Spark Streaming + Kinesis Integration
---
[Amazon Kinesis](http://aws.amazon.com/kinesis/) is a fully managed service for real-time processing of streaming data at massive scale.
-The Kinesis input DStream and receiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License (ASL).
-The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
+The Kinesis receiver creates an input DStream using the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License (ASL).
+The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concepts of Workers, Checkpoints, and Shard Leases.
Here we explain how to configure Spark Streaming to receive data from Kinesis.
#### Configuring Kinesis
@@ -15,7 +15,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
#### Configuring Spark Streaming Application
-1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information).
+1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information).
groupId = org.apache.spark
artifactId = spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}
@@ -23,10 +23,11 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.**
-2. **Programming:** In the streaming application code, import `KinesisUtils` and create input DStream as follows.
+2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows:
+ import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.kinesis._
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
@@ -34,11 +35,13 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position])
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the next subsection for instructions to run the example.
+ and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example.
- import org.apache.spark.streaming.flume.*;
+ import org.apache.spark.streaming.Duration;
+ import org.apache.spark.streaming.kinesis.*;
+ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
 JavaReceiverInputDStream<byte[]> kinesisStream = KinesisUtils.createStream(
streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]);
@@ -49,36 +52,73 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
- `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region).
+ - `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream
- `[checkpoint interval]`: The interval at which the Kinesis client library is going to save its position in the stream. For starters, set it to the same as the batch interval of the streaming application.
+ - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from
+ - The application name used in the streaming context becomes the Kinesis application name
+ - The application name must be unique for a given account and region.
+ - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization.
+ - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table.
- `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see later section and Amazon Kinesis API documentation for more details).
- *Points to remember:*
+ - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region).
- - The name used in the context of the streaming application must be unique for a given account and region. Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
- - A single Kinesis input DStream can receive many Kinesis shards by spinning up multiple KinesisRecordProcessor threads. Note that there is no correlation between number of shards in Kinesis and the number of partitions in the generated RDDs that is used for processing the data.
- - You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
- - Horizontal scaling is achieved by autoscaling additional Kinesis input DStreams (separate processes) up to the number of current shards for a given stream, of course.
+ - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application.
-3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide).
+ - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details).
- - A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL. It is used to store KCL's checkpoint information.
- - If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
+3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide).
+
+ *Points to remember at runtime:*
+
+ - Kinesis data processing is ordered per partition and occurs at-least once per message.
+
+ - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB.
+
+ - A single Kinesis stream shard is processed by one input DStream at a time.
+
+
+
+
+
+
+ - A single Kinesis input DStream can read from multiple shards of a Kinesis stream by creating multiple KinesisRecordProcessor threads.
+
+ - Multiple input DStreams running in separate processes/instances can read from a Kinesis stream.
+
+ - You never need more Kinesis input DStreams than the number of Kinesis stream shards as each input DStream will create at least one KinesisRecordProcessor thread that handles a single shard.
+
+ - Horizontal scaling is achieved by adding/removing Kinesis input DStreams (within a single process or across multiple processes/instances) - up to the total number of Kinesis stream shards per the previous point.
+
+ - The Kinesis input DStream will balance the load between all DStreams - even across processes/instances.
+
+ - The Kinesis input DStream will balance the load during re-shard events (merging and splitting) due to changes in load.
+
+ - As a best practice, it's recommended that you avoid re-shard jitter by over-provisioning when possible.
+
+ - Each Kinesis input DStream maintains its own checkpoint info. See the Kinesis Checkpointing section for more details.
+
+ - There is no correlation between the number of Kinesis stream shards and the number of RDD partitions/shards created across the Spark cluster during input DStream processing. These are 2 independent partitioning schemes.
#### Running the Example
To run the example,
+
- Download Spark source and follow the [instructions](building-with-maven.html) to build Spark with profile *-Pkinesis-asl*.
- mvn -Pkinesis-asl -DskipTests clean package
+ mvn -Pkinesis-asl -DskipTests clean package
+
-- Set up Kinesis stream (see earlier section). Note the name of the Kinesis stream, and the endpoint URL corresponding to the region the stream is based on.
+- Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created.
- Set up the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_KEY with your AWS credentials.
- In the Spark root directory, run the example as
+
@@ -92,19 +132,19 @@ To run the example,
- This will wait for data to be received from Kinesis.
+ This will wait for data to be received from the Kinesis stream.
-- To generate random string data, in another terminal, run the associated Kinesis data producer.
+- To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer.
bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10
- This will push random words to the Kinesis stream, which should then be received and processed by the running example.
+ This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example.
#### Kinesis Checkpointing
-The Kinesis receiver checkpoints the position of the stream that has been read periodically, so that the system can recover from failures and continue processing where it had left off. Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy.
-
-- If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
+- Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off.
-- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored). In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
+- Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy.
-- InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
+- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable.
+- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored).
+- InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 3d4bce49666ed..41f170580f452 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -233,7 +233,7 @@ $ ./bin/run-example streaming.NetworkWordCount localhost 9999
{% highlight bash %}
-$ ./bin/run-example JavaNetworkWordCount localhost 9999
+$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999
{% endhighlight %}
@@ -262,7 +262,7 @@ hello world
{% highlight bash %}
# TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount
-$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999
+$ ./bin/run-example streaming.NetworkWordCount localhost 9999
...
-------------------------------------------
Time: 1357008430000 ms
@@ -285,12 +285,22 @@ need to know to write your streaming applications.
## Linking
-To write your own Spark Streaming program, you will have to add the following dependency to your
- SBT or Maven project:
+Similar to Spark, Spark Streaming is available through Maven Central. To write your own Spark Streaming program, you will have to add the following dependency to your SBT or Maven project.
+
+
+
- groupId = org.apache.spark
- artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}}
- version = {{site.SPARK_VERSION}}
+In Maven:
+
+    <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-streaming_{{site.SCALA_BINARY_VERSION}}</artifactId>
+        <version>{{site.SPARK_VERSION}}</version>
+    </dependency>
+
+In SBT:
+
+    libraryDependencies += "org.apache.spark" % "spark-streaming_{{site.SCALA_BINARY_VERSION}}" % "{{site.SPARK_VERSION}}"
+
For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark
Streaming core
@@ -302,7 +312,7 @@ some of the common ones are as follows.
Source | Artifact |
Kafka | spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} |
Flume | spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} |
- Kinesis | spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} |
+ Kinesis | spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Apache Software License] |
Twitter | spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} |
ZeroMQ | spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} |
MQTT | spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} |
@@ -373,7 +383,7 @@ or a special __"local[\*]"__ string to run in local mode. In practice, when runn
you will not want to hardcode `master` in the program,
but rather [launch the application with `spark-submit`](submitting-applications.html) and
receive it there. However, for local testing and unit tests, you can pass "local[*]" to run Spark Streaming
-in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`.
+in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`.
The batch interval must be set based on the latency requirements of your application
and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
@@ -447,11 +457,12 @@ Spark Streaming has two categories of streaming sources.
- *Basic sources*: Sources directly available in the StreamingContext API. Example: file systems, socket connections, and Akka actors.
- *Advanced sources*: Sources like Kafka, Flume, Kinesis, Twitter, etc. are available through extra utility classes. These require linking against extra dependencies as discussed in the [linking](#linking) section.
-Every input DStream (except file stream) is associated with a single [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) object which receives the data from a source and stores it in Spark's memory for processing. A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are:
+Every input DStream (except file stream) is associated with a single [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) object which receives the data from a source and stores it in Spark's memory for processing. So every input DStream receives a single stream of data. Note that in a streaming application, you can create multiple input DStreams to receive multiple streams of data in parallel. This is discussed later in the [Performance Tuning](#level-of-parallelism-in-data-receiving) section.
+
+A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. It is therefore important to remember that a Spark Streaming application needs to be allocated enough cores to process the received data as well as to run the receiver(s). A few important points to remember are:
##### Points to remember:
{:.no_toc}
-
- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them.
- When running locally, if your master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data (see the sketch after this list).
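+
+For instance (a minimal sketch, not part of the original guide; the application name, host, and port
+are placeholders mirroring the NetworkWordCount example), a local test would allocate at least two
+cores so that one can run the receiver and one can process the received data:
+
+{% highlight scala %}
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+import org.apache.spark.streaming.StreamingContext._
+
+// "local[2]" provides two cores: one to run the socket receiver and one to process data.
+val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
+val ssc = new StreamingContext(conf, Seconds(1))
+
+val lines = ssc.socketTextStream("localhost", 9999)
+lines.flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _).print()
+
+ssc.start()
+ssc.awaitTermination()
+{% endhighlight %}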
@@ -1089,9 +1100,34 @@ parallelizing the data receiving. Note that each input DStream
creates a single receiver (running on a worker machine) that receives a single stream of data.
Receiving multiple data streams can therefore be achieved by creating multiple input DStreams
and configuring them to receive different partitions of the data stream from the source(s).
-For example, a single Kafka input stream receiving two topics of data can be split into two
+For example, a single Kafka input DStream receiving two topics of data can be split into two
Kafka input streams, each receiving only one topic. This would run two receivers on two workers,
-thus allowing data to be received in parallel, and increasing overall throughput.
+thus allowing data to be received in parallel, and increasing overall throughput. These multiple
+DStreams can be unioned together to create a single DStream. Then the transformations that were
+being applied on the single input DStream can be applied on the unified stream. This is done as follows.
+
+
+
+{% highlight scala %}
+val numStreams = 5
+val kafkaStreams = (1 to numStreams).map { i => KafkaUtils.createStream(...) }
+val unifiedStream = streamingContext.union(kafkaStreams)
+unifiedStream.print()
+{% endhighlight %}
+
+
+{% highlight java %}
+int numStreams = 5;
+List<JavaPairDStream<String, String>> kafkaStreams = new ArrayList<JavaPairDStream<String, String>>(numStreams);
+for (int i = 0; i < numStreams; i++) {
+ kafkaStreams.add(KafkaUtils.createStream(...));
+}
+JavaPairDStream<String, String> unifiedStream = streamingContext.union(kafkaStreams.get(0), kafkaStreams.subList(1, kafkaStreams.size()));
+unifiedStream.print();
+{% endhighlight %}
+
+
+
Another parameter that should be considered is the receiver's blocking interval. For most receivers,
the received data is coalesced together into large blocks of data before storing inside Spark's memory.
@@ -1107,7 +1143,7 @@ before further processing.
### Level of Parallelism in Data Processing
{:.no_toc}
-Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the
+Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the
computation is not high enough. For example, for distributed reduce operations like `reduceByKey`
and `reduceByKeyAndWindow`, the default number of parallel tasks is decided by the [config property]
(configuration.html#spark-properties) `spark.default.parallelism`. You can pass the level of
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 1670faca4a480..bfd07593b92ed 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -38,9 +38,12 @@
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType
from boto import ec2
+DEFAULT_SPARK_VERSION = "1.0.0"
+
# A URL prefix from which to fetch AMI information
AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
+
class UsageError(Exception):
pass
@@ -56,10 +59,10 @@ def parse_args():
help="Show this help message and exit")
parser.add_option(
"-s", "--slaves", type="int", default=1,
- help="Number of slaves to launch (default: 1)")
+ help="Number of slaves to launch (default: %default)")
parser.add_option(
"-w", "--wait", type="int", default=120,
- help="Seconds to wait for nodes to start (default: 120)")
+ help="Seconds to wait for nodes to start (default: %default)")
parser.add_option(
"-k", "--key-pair",
help="Key pair to use on instances")
@@ -68,7 +71,7 @@ def parse_args():
help="SSH private key file to use for logging into instances")
parser.add_option(
"-t", "--instance-type", default="m1.large",
- help="Type of instance to launch (default: m1.large). " +
+ help="Type of instance to launch (default: %default). " +
"WARNING: must be 64-bit; small instances won't work")
parser.add_option(
"-m", "--master-instance-type", default="",
@@ -83,15 +86,15 @@ def parse_args():
"between zones applies)")
parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
parser.add_option(
- "-v", "--spark-version", default="1.0.0",
- help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
+ "-v", "--spark-version", default=DEFAULT_SPARK_VERSION,
+ help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)")
parser.add_option(
"--spark-git-repo",
default="https://github.com/apache/spark",
help="Github repo from which to checkout supplied commit hash")
parser.add_option(
"--hadoop-major-version", default="1",
- help="Major version of Hadoop (default: 1)")
+ help="Major version of Hadoop (default: %default)")
parser.add_option(
"-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
@@ -115,21 +118,21 @@ def parse_args():
"Only support up to 8 EBS volumes.")
parser.add_option(
"--swap", metavar="SWAP", type="int", default=1024,
- help="Swap space to set up per node, in MB (default: 1024)")
+ help="Swap space to set up per node, in MB (default: %default)")
parser.add_option(
"--spot-price", metavar="PRICE", type="float",
help="If specified, launch slaves as spot instances with the given " +
"maximum price (in dollars)")
parser.add_option(
"--ganglia", action="store_true", default=True,
- help="Setup Ganglia monitoring on cluster (default: on). NOTE: " +
+ help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " +
"the Ganglia page will be publicly accessible")
parser.add_option(
"--no-ganglia", action="store_false", dest="ganglia",
help="Disable Ganglia monitoring for the cluster")
parser.add_option(
"-u", "--user", default="root",
- help="The SSH user you want to connect as (default: root)")
+ help="The SSH user you want to connect as (default: %default)")
parser.add_option(
"--delete-groups", action="store_true", default=False,
help="When destroying a cluster, delete the security groups that were created.")
@@ -138,7 +141,7 @@ def parse_args():
help="Launch fresh slaves, but use an existing stopped master if possible")
parser.add_option(
"--worker-instances", type="int", default=1,
- help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: 1)")
+ help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)")
parser.add_option(
"--master-opts", type="string", default="",
help="Extra options to give to master through SPARK_MASTER_OPTS variable " +
@@ -151,7 +154,7 @@ def parse_args():
help="Use this prefix for the security group rather than the cluster name.")
parser.add_option(
"--authorized-address", type="string", default="0.0.0.0/0",
- help="Address to authorize on created security groups (default: 0.0.0.0/0)")
+ help="Address to authorize on created security groups (default: %default)")
parser.add_option(
"--additional-security-group", type="string", default="",
help="Additional security group to place the machines in")
@@ -342,7 +345,6 @@ def launch_cluster(conn, opts, cluster_name):
if opts.ami is None:
opts.ami = get_spark_ami(opts)
-
additional_groups = []
if opts.additional_security_group:
additional_groups = [sg
@@ -363,7 +365,7 @@ def launch_cluster(conn, opts, cluster_name):
for i in range(opts.ebs_vol_num):
device = EBSBlockDeviceType()
device.size = opts.ebs_vol_size
- device.volume_type=opts.ebs_vol_type
+ device.volume_type = opts.ebs_vol_type
device.delete_on_termination = True
block_map["/dev/sd" + chr(ord('s') + i)] = device
@@ -495,6 +497,7 @@ def launch_cluster(conn, opts, cluster_name):
# Return all the instances
return (master_nodes, slave_nodes)
+
def tag_instance(instance, name):
for i in range(0, 5):
try:
@@ -507,9 +510,12 @@ def tag_instance(instance, name):
# Get the EC2 instances in an existing cluster if available.
# Returns a tuple of lists of EC2 instance objects for the masters and slaves
+
+
def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
print "Searching for existing cluster " + cluster_name + "..."
- # Search all the spot instance requests, and copy any tags from the spot instance request to the cluster.
+ # Search all the spot instance requests, and copy any tags from the spot
+ # instance request to the cluster.
spot_instance_requests = conn.get_all_spot_instance_requests()
for req in spot_instance_requests:
if req.state != u'active':
@@ -520,7 +526,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
for res in reservations:
active = [i for i in res.instances if is_active(i)]
for instance in active:
- if (instance.tags.get(u'Name') == None):
+ if (instance.tags.get(u'Name') is None):
tag_instance(instance, name)
# Now proceed to detect master and slaves instances.
reservations = conn.get_all_instances()
@@ -540,13 +546,16 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
return (master_nodes, slave_nodes)
else:
if master_nodes == [] and slave_nodes != []:
- print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master"
+ print >> sys.stderr, "ERROR: Could not find master in with name " + \
+ cluster_name + "-master"
else:
print >> sys.stderr, "ERROR: Could not find any existing cluster"
sys.exit(1)
# Deploy configuration files and run setup scripts on a newly launched
# or started EC2 cluster.
+
+
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
@@ -890,7 +899,8 @@ def real_main():
if opts.security_group_prefix is None:
group_names = [cluster_name + "-master", cluster_name + "-slaves"]
else:
- group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"]
+ group_names = [opts.security_group_prefix + "-master",
+ opts.security_group_prefix + "-slaves"]
attempt = 1
while attempt <= 3:
diff --git a/examples/pom.xml b/examples/pom.xml
index 9b12cb0c29c9f..3f46c40464d3b 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../pom.xml
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
index e4468e8bf1744..1f82e3f4cb18e 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
@@ -63,7 +63,7 @@ public static void main(String[] args) {
 HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
String impurity = "gini";
Integer maxDepth = 5;
- Integer maxBins = 100;
+ Integer maxBins = 32;
// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index e902ae29753c0..cfda8d8327aa3 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -23,7 +23,8 @@
Read data file users.avro in local Spark distro:
$ cd $SPARK_HOME
-$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \
+$ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+> ./examples/src/main/python/avro_inputformat.py \
> examples/src/main/resources/users.avro
{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]}
{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []}
@@ -40,7 +41,8 @@
]
}
-$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \
+$ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+> ./examples/src/main/python/avro_inputformat.py \
> examples/src/main/resources/users.avro examples/src/main/resources/user.avsc
{u'favorite_color': None, u'name': u'Alyssa'}
{u'favorite_color': u'red', u'name': u'Ben'}
@@ -51,8 +53,10 @@
Usage: avro_inputformat [reader_schema_file]
Run with example jar:
- ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file]
- Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file].
+ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+ /path/to/examples/avro_inputformat.py [reader_schema_file]
+ Assumes you have Avro data stored in . Reader schema can be optionally specified
+ in [reader_schema_file].
"""
exit(-1)
@@ -62,9 +66,10 @@
conf = None
if len(sys.argv) == 3:
schema_rdd = sc.textFile(sys.argv[2], 1).collect()
- conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)}
+ conf = {"avro.schema.input.key": reduce(lambda x, y: x + y, schema_rdd)}
- avro_rdd = sc.newAPIHadoopFile(path,
+ avro_rdd = sc.newAPIHadoopFile(
+ path,
"org.apache.avro.mapreduce.AvroKeyInputFormat",
"org.apache.avro.mapred.AvroKey",
"org.apache.hadoop.io.NullWritable",
diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py
index e4a897f61e39d..05f34b74df45a 100644
--- a/examples/src/main/python/cassandra_inputformat.py
+++ b/examples/src/main/python/cassandra_inputformat.py
@@ -51,7 +51,8 @@
 Usage: cassandra_inputformat <host> <keyspace> <cf>
Run with example jar:
- ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_inputformat.py
+ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+ /path/to/examples/cassandra_inputformat.py
 Assumes you have some data in Cassandra already, running on <host>, in <keyspace> and <cf>
"""
exit(-1)
@@ -61,12 +62,12 @@
cf = sys.argv[3]
sc = SparkContext(appName="CassandraInputFormat")
- conf = {"cassandra.input.thrift.address":host,
- "cassandra.input.thrift.port":"9160",
- "cassandra.input.keyspace":keyspace,
- "cassandra.input.columnfamily":cf,
- "cassandra.input.partitioner.class":"Murmur3Partitioner",
- "cassandra.input.page.row.size":"3"}
+ conf = {"cassandra.input.thrift.address": host,
+ "cassandra.input.thrift.port": "9160",
+ "cassandra.input.keyspace": keyspace,
+ "cassandra.input.columnfamily": cf,
+ "cassandra.input.partitioner.class": "Murmur3Partitioner",
+ "cassandra.input.page.row.size": "3"}
cass_rdd = sc.newAPIHadoopRDD(
"org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat",
"java.util.Map",
diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py
index 836c35b5c6794..d144539e58b8f 100644
--- a/examples/src/main/python/cassandra_outputformat.py
+++ b/examples/src/main/python/cassandra_outputformat.py
@@ -50,7 +50,8 @@
Usage: cassandra_outputformat
Run with example jar:
- ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py
+ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+ /path/to/examples/cassandra_outputformat.py
 Assumes you have created the following table <cf> in Cassandra already,
 running on <host>, in <keyspace>.
@@ -67,16 +68,16 @@
cf = sys.argv[3]
sc = SparkContext(appName="CassandraOutputFormat")
- conf = {"cassandra.output.thrift.address":host,
- "cassandra.output.thrift.port":"9160",
- "cassandra.output.keyspace":keyspace,
- "cassandra.output.partitioner.class":"Murmur3Partitioner",
- "cassandra.output.cql":"UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?",
- "mapreduce.output.basename":cf,
- "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat",
- "mapreduce.job.output.key.class":"java.util.Map",
- "mapreduce.job.output.value.class":"java.util.List"}
- key = {"user_id" : int(sys.argv[4])}
+ conf = {"cassandra.output.thrift.address": host,
+ "cassandra.output.thrift.port": "9160",
+ "cassandra.output.keyspace": keyspace,
+ "cassandra.output.partitioner.class": "Murmur3Partitioner",
+ "cassandra.output.cql": "UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?",
+ "mapreduce.output.basename": cf,
+ "mapreduce.outputformat.class": "org.apache.cassandra.hadoop.cql3.CqlOutputFormat",
+ "mapreduce.job.output.key.class": "java.util.Map",
+ "mapreduce.job.output.value.class": "java.util.List"}
+ key = {"user_id": int(sys.argv[4])}
sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset(
conf=conf,
keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter",
diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py
index befacee0dea56..3b16010f1cb97 100644
--- a/examples/src/main/python/hbase_inputformat.py
+++ b/examples/src/main/python/hbase_inputformat.py
@@ -51,7 +51,8 @@
 Usage: hbase_inputformat <host> <table>
Run with example jar:
- ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_inputformat.py
+ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+ /path/to/examples/hbase_inputformat.py
 Assumes you have some data in HBase already, running on <host>, in <table>
"""
exit(-1)
@@ -61,12 +62,15 @@
sc = SparkContext(appName="HBaseInputFormat")
conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table}
+ keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter"
+ valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter"
+
hbase_rdd = sc.newAPIHadoopRDD(
"org.apache.hadoop.hbase.mapreduce.TableInputFormat",
"org.apache.hadoop.hbase.io.ImmutableBytesWritable",
"org.apache.hadoop.hbase.client.Result",
- keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter",
- valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter",
+ keyConverter=keyConv,
+ valueConverter=valueConv,
conf=conf)
output = hbase_rdd.collect()
for (k, v) in output:
diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py
index 49bbc5aebdb0b..abb425b1f886a 100644
--- a/examples/src/main/python/hbase_outputformat.py
+++ b/examples/src/main/python/hbase_outputformat.py
@@ -44,8 +44,10 @@
Usage: hbase_outputformat
Run with example jar:
- ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py
- Assumes you have created with column family in HBase running on already
+ ./bin/spark-submit --driver-class-path /path/to/example/jar \
+ /path/to/examples/hbase_outputformat.py
+ Assumes you have created <table> with column family <family> in HBase
+ running on <host> already
"""
exit(-1)
@@ -55,13 +57,15 @@
conf = {"hbase.zookeeper.quorum": host,
"hbase.mapred.outputtable": table,
- "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat",
- "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable",
- "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"}
+ "mapreduce.outputformat.class": "org.apache.hadoop.hbase.mapreduce.TableOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.hbase.io.ImmutableBytesWritable",
+ "mapreduce.job.output.value.class": "org.apache.hadoop.io.Writable"}
+ keyConv = "org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter"
+ valueConv = "org.apache.spark.examples.pythonconverters.StringListToPutConverter"
sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset(
conf=conf,
- keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter",
- valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter")
+ keyConverter=keyConv,
+ valueConverter=valueConv)
sc.stop()
diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py
index 6b16a56e44af7..4218eca822a99 100755
--- a/examples/src/main/python/mllib/correlations.py
+++ b/examples/src/main/python/mllib/correlations.py
@@ -28,7 +28,7 @@
if __name__ == "__main__":
- if len(sys.argv) not in [1,2]:
+ if len(sys.argv) not in [1, 2]:
print >> sys.stderr, "Usage: correlations ()"
exit(-1)
sc = SparkContext(appName="PythonCorrelations")
diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py
index 6e4a4a0cb6be0..61ea4e06ecf3a 100755
--- a/examples/src/main/python/mllib/decision_tree_runner.py
+++ b/examples/src/main/python/mllib/decision_tree_runner.py
@@ -21,7 +21,9 @@
This example requires NumPy (http://www.numpy.org/).
"""
-import numpy, os, sys
+import numpy
+import os
+import sys
from operator import add
@@ -127,7 +129,7 @@ def usage():
(reindexedData, origToNewLabels) = reindexClassLabels(points)
# Train a classifier.
- categoricalFeaturesInfo={} # no categorical features
+ categoricalFeaturesInfo = {} # no categorical features
model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
# Print learned tree and stats.
diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py
index b388d8d83fb86..1e8892741e714 100755
--- a/examples/src/main/python/mllib/random_rdd_generation.py
+++ b/examples/src/main/python/mllib/random_rdd_generation.py
@@ -32,8 +32,8 @@
sc = SparkContext(appName="PythonRandomRDDGeneration")
- numExamples = 10000 # number of examples to generate
- fraction = 0.1 # fraction of data to sample
+ numExamples = 10000 # number of examples to generate
+ fraction = 0.1 # fraction of data to sample
# Example: RandomRDDs.normalRDD
normalRDD = RandomRDDs.normalRDD(sc, numExamples)
@@ -45,7 +45,7 @@
print
# Example: RandomRDDs.normalVectorRDD
- normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2)
+ normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2)
print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()
print ' First 5 samples:'
for sample in normalVectorRDD.take(5):
diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py
index ec64a5978c672..92af3af5ebd1e 100755
--- a/examples/src/main/python/mllib/sampled_rdds.py
+++ b/examples/src/main/python/mllib/sampled_rdds.py
@@ -36,7 +36,7 @@
sc = SparkContext(appName="PythonSampledRDDs")
- fraction = 0.1 # fraction of data to sample
+ fraction = 0.1 # fraction of data to sample
examples = MLUtils.loadLibSVMFile(sc, datapath)
numExamples = examples.count()
@@ -49,9 +49,9 @@
expectedSampleSize = int(numExamples * fraction)
print 'Sampling RDD using fraction %g. Expected sample size = %d.' \
% (fraction, expectedSampleSize)
- sampledRDD = examples.sample(withReplacement = True, fraction = fraction)
+ sampledRDD = examples.sample(withReplacement=True, fraction=fraction)
print ' RDD.sample(): sample has %d examples' % sampledRDD.count()
- sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize)
+ sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize)
print ' RDD.takeSample(): sample has %d examples' % len(sampledArray)
print
@@ -66,7 +66,7 @@
fractions = {}
for k in keyCountsA.keys():
fractions[k] = fraction
- sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions)
+ sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions)
keyCountsB = sampledByKeyRDD.countByKey()
sizeB = sum(keyCountsB.values())
print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \
diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py
index fc37459dc74aa..ee9036adfa281 100755
--- a/examples/src/main/python/pi.py
+++ b/examples/src/main/python/pi.py
@@ -35,7 +35,7 @@ def f(_):
y = random() * 2 - 1
return 1 if x ** 2 + y ** 2 < 1 else 0
- count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add)
+ count = sc.parallelize(xrange(1, n + 1), slices).map(f).reduce(add)
print "Pi is roughly %f" % (4.0 * count / n)
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
index bdc8fa7f99f2e..e809a65b79975 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
@@ -20,7 +20,7 @@ package org.apache.spark.examples.graphx
import org.apache.spark.SparkContext._
import org.apache.spark._
import org.apache.spark.graphx._
-import org.apache.spark.examples.graphx.Analytics
+
/**
* Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index cf3d2cca81ff6..72c3ab475b61f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -52,9 +52,9 @@ object DecisionTreeRunner {
input: String = null,
dataFormat: String = "libsvm",
algo: Algo = Classification,
- maxDepth: Int = 4,
+ maxDepth: Int = 5,
impurity: ImpurityType = Gini,
- maxBins: Int = 100,
+ maxBins: Int = 32,
fracTest: Double = 0.2)
def main(args: Array[String]) {
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index b345276b08ba3..ac291bd4fde20 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index f71f6b6c4f931..7d31e32283d88 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 4e2275ab238f7..2067c473f0e3f 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index dc48a08c93de2..371f1f1e9d39a 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index b93ad016f84f0..1d7dd49d15c22 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 22c1fff23d9a2..7e48968feb3bc 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 5308bb4e440ea..8658ecf5abfab 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index a54b34235dfb4..560244ad93369 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index a5b162a0482e4..71a078d58a8d8 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 6dd52fc618b1e..3f49b1d63b6e1 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../pom.xml
diff --git a/make-distribution.sh b/make-distribution.sh
index f030d3f430581..9b012b9222db4 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -28,7 +28,7 @@ set -o pipefail
set -e
# Figure out where the Spark framework is installed
-FWDIR="$(cd `dirname $0`; pwd)"
+FWDIR="$(cd "`dirname "$0"`"; pwd)"
DISTDIR="$FWDIR/dist"
SPARK_TACHYON=false
@@ -50,7 +50,8 @@ while (( "$#" )); do
case $1 in
--hadoop)
echo "Error: '--hadoop' is no longer supported:"
- echo "Error: use Maven options -Phadoop.version and -Pyarn.version"
+ echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead."
+ echo "Error: Related profiles include hadoop-0.23, hdaoop-2.2, hadoop-2.3 and hadoop-2.4."
exit_with_usage
;;
--with-yarn)
diff --git a/mllib/pom.xml b/mllib/pom.xml
index c7a1e2ae75c84..a5eeef88e9d62 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../pom.xml
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 5cdd258f6c20b..d1309b2b20f54 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl._
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
+import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -65,36 +66,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
logDebug("algo = " + strategy.algo)
+ logDebug("maxBins = " + metadata.maxBins)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
- val numBins = bins(0).length
timer.stop("findSplitsBins")
- logDebug("numBins = " + numBins)
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
.persist(StorageLevel.MEMORY_AND_DISK)
- val numFeatures = metadata.numFeatures
// depth of the decision tree
val maxDepth = strategy.maxDepth
- // the max number of nodes possible given the depth of the tree
- val maxNumNodes = (2 << maxDepth) - 1
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+ // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
+ val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
// Initialize an array to hold parent impurity calculations for each node.
- val parentImpurities = new Array[Double](maxNumNodes)
+ val parentImpurities = new Array[Double](maxNumNodesPlus1)
// dummy value for top node (updated during first split calculation)
- val nodes = new Array[Node](maxNumNodes)
+ val nodes = new Array[Node](maxNumNodesPlus1)
// Calculate level for single group construction
// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
+ // TODO: Calculate memory usage more precisely.
+ val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
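
For reference, a minimal sketch of the 1-based node indexing that the Node helpers used in this hunk imply (nodes(1) is the root; a node's children sit at twice its index and twice its index plus one). The helper names mirror the calls in the diff, but the bodies here are inferred, not copied from the patch:

// Sketch only: binary-heap style, 1-based node indexing inferred from the diff
// (topNode = nodes(1); an array of size startIndexInLevel(maxDepth + 1) holds
// every node of a tree of depth maxDepth, with index 0 unused).
object NodeIndexSketch {
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
  // Level L occupies indices [2^L, 2^(L+1) - 1], so both of these are 2^L.
  def maxNodesInLevel(level: Int): Int = 1 << level
  def startIndexInLevel(level: Int): Int = 1 << level
  // The maxDepth <= 30 check above keeps 2^(maxDepth + 1) within Int range.
}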
@@ -124,26 +130,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Find best split for all nodes at a level.
timer.start("findBestSplits")
- val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
- metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+ val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+ DecisionTree.findBestSplits(treeInput, parentImpurities,
+ metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")
- val levelNodeIndexOffset = (1 << level) - 1
+ val levelNodeIndexOffset = Node.startIndexInLevel(level)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
val nodeIndex = levelNodeIndexOffset + index
- val isLeftChild = level != 0 && nodeIndex % 2 == 1
- val parentNodeIndex = if (isLeftChild) { // -1 for root node
- (nodeIndex - 1) / 2
- } else {
- (nodeIndex - 2) / 2
- }
+
// Extract info for this node (index) at the current level.
timer.start("extractNodeInfo")
- extractNodeInfo(nodeSplitStats, level, index, nodes)
+ val split = nodeSplitStats._1
+ val stats = nodeSplitStats._2
+ val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
+ val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ logDebug("Node = " + node)
+ nodes(nodeIndex) = node
timer.stop("extractNodeInfo")
+
if (level != 0) {
// Set parent.
- if (isLeftChild) {
+ val parentNodeIndex = Node.parentIndex(nodeIndex)
+ if (Node.isLeftChild(nodeIndex)) {
nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
} else {
nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
@@ -151,11 +160,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
// Extract info for nodes at the next lower level.
timer.start("extractInfoForLowerLevels")
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
+ if (level < maxDepth) {
+ val leftChildIndex = Node.leftChildIndex(nodeIndex)
+ val leftImpurity = stats.leftImpurity
+ logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
+ parentImpurities(leftChildIndex) = leftImpurity
+
+ val rightChildIndex = Node.rightChildIndex(nodeIndex)
+ val rightImpurity = stats.rightImpurity
+ logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
+ parentImpurities(rightChildIndex) = rightImpurity
+ }
timer.stop("extractInfoForLowerLevels")
- logDebug("final best split = " + nodeSplitStats._1)
+ logDebug("final best split = " + split)
}
- require((1 << level) == splitsStatsForLevel.length)
+ require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level are leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
@@ -171,7 +190,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("#####################################")
// Initialize the top or root node of the tree.
- val topNode = nodes(0)
+ val topNode = nodes(1)
// Build the full tree using the node info calculated in the level-wise best split calculations.
topNode.build(nodes)
@@ -183,47 +202,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
new DecisionTreeModel(topNode, strategy.algo)
}
- /**
- * Extract the decision tree node information for the given tree level and node index
- */
- private def extractNodeInfo(
- nodeSplitStats: (Split, InformationGainStats),
- level: Int,
- index: Int,
- nodes: Array[Node]): Unit = {
- val split = nodeSplitStats._1
- val stats = nodeSplitStats._2
- val nodeIndex = (1 << level) - 1 + index
- val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
- val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
- logDebug("Node = " + node)
- nodes(nodeIndex) = node
- }
-
- /**
- * Extract the decision tree node information for the children of the node
- */
- private def extractInfoForLowerLevels(
- level: Int,
- index: Int,
- maxDepth: Int,
- nodeSplitStats: (Split, InformationGainStats),
- parentImpurities: Array[Double]): Unit = {
-
- if (level >= maxDepth) {
- return
- }
-
- val leftNodeIndex = (2 << level) - 1 + 2 * index
- val leftImpurity = nodeSplitStats._2.leftImpurity
- logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
- parentImpurities(leftNodeIndex) = leftImpurity
-
- val rightNodeIndex = leftNodeIndex + 1
- val rightImpurity = nodeSplitStats._2.rightImpurity
- logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
- parentImpurities(rightNodeIndex) = rightImpurity
- }
}
object DecisionTree extends Serializable with Logging {
@@ -352,9 +330,9 @@ object DecisionTree extends Serializable with Logging {
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
+ * (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
+ * (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
def trainClassifier(
@@ -396,9 +374,9 @@ object DecisionTree extends Serializable with Logging {
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
+ * (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
+ * (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
def trainRegressor(
@@ -425,9 +403,6 @@ object DecisionTree extends Serializable with Logging {
impurity, maxDepth, maxBins)
}
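
With the suggested values updated to maxDepth = 5 and maxBins = 32, a hedged usage sketch of the String-based trainClassifier overload documented above; the SparkContext sc and the libsvm path are assumptions, not part of this patch:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Sketch: train with the newly suggested parameter values.
// "data/sample_libsvm_data.txt" is a placeholder path.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
val model = DecisionTree.trainClassifier(
  data,
  2,                    // number of classes
  Map.empty[Int, Int],  // no categorical features
  "gini",               // impurity
  5,                    // maxDepth: suggested value after this change (was 4)
  32)                   // maxBins: suggested value after this change (was 100)
println(model)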
-
- private val InvalidBinIndex = -1
-
/**
* Returns an array of optimal splits for all nodes at a given level. Splits the task into
* multiple groups if the level-wise training task could lead to memory overflow.
@@ -436,12 +411,12 @@ object DecisionTree extends Serializable with Logging {
* @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param splits possible splits for all features
- * @param bins possible bins for all features
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array (over nodes) of splits with best split for each node at a given level.
*/
- protected[tree] def findBestSplits(
+ private[tree] def findBestSplits(
input: RDD[TreePoint],
parentImpurities: Array[Double],
metadata: DecisionTreeMetadata,
@@ -474,6 +449,138 @@ object DecisionTree extends Serializable with Logging {
}
}
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a node
+ * at the current level being trained; that node's index is returned.
+ *
+ * @param node Node in tree from which to classify the given data point.
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * set of nodes in a (level, group).
+ */
+ private def predictNodeIndex(
+ node: Node,
+ binnedFeatures: Array[Int],
+ bins: Array[Array[Bin]],
+ unorderedFeatures: Set[Int]): Int = {
+ if (node.isLeaf) {
+ node.id
+ } else {
+ val featureIndex = node.split.get.feature
+ val splitLeft = node.split.get.featureType match {
+ case Continuous => {
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+ // We do not need to check lowSplit since bins are separated by splits.
+ featureValueUpperBound <= node.split.get.threshold
+ }
+ case Categorical => {
+ val featureValue = binnedFeatures(featureIndex)
+ node.split.get.categories.contains(featureValue)
+ }
+ case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+ }
+ if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+ // Return index from next layer of nodes to train
+ if (splitLeft) {
+ Node.leftChildIndex(node.id)
+ } else {
+ Node.rightChildIndex(node.id)
+ }
+ } else {
+ if (splitLeft) {
+ predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
+ } else {
+ predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
+ *
+ * For ordered features, a single bin is updated.
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
+ * for each subset is updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param unorderedFeatures Set of indices of unordered features.
+ */
+ private def mixedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ nodeIndex: Int,
+ bins: Array[Array[Bin]],
+ unorderedFeatures: Set[Int]): Unit = {
+ // Iterate over all features.
+ val numFeatures = treePoint.binnedFeatures.size
+ val nodeOffset = agg.getNodeOffset(nodeIndex)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ if (unorderedFeatures.contains(featureIndex)) {
+ // Unordered feature
+ val featureValue = treePoint.binnedFeatures(featureIndex)
+ val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
+ agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ // Update the left or right bin for each split.
+ val numSplits = agg.numSplits(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+ agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label)
+ } else {
+ agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label)
+ }
+ splitIndex += 1
+ }
+ } else {
+ // Ordered feature
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+ }
+ featureIndex += 1
+ }
+ }
+
+ /**
+ * Helper for binSeqOp, for regression and for classification with only ordered features.
+ *
+ * For each feature, the sufficient statistics of one bin are updated.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ */
+ private def orderedBinSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint,
+ nodeIndex: Int): Unit = {
+ val label = treePoint.label
+ val nodeOffset = agg.getNodeOffset(nodeIndex)
+ // Iterate over all features.
+ val numFeatures = agg.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
+ featureIndex += 1
+ }
+ }
+
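
The offset calls used by mixedBinSeqOp and orderedBinSeqOp imply a flat storage layout with a fixed number of statistics per (node, feature, bin); below is a minimal sketch of that layout, assuming statsSize = numClasses for classification (per-label counts) and 3 for regression, consistent with getElementsPerNode later in this file. The real DTStatsAggregator added by this patch may differ in details:

// Sketch only: one flat Array[Double] with statsSize values per (node, feature, bin).
// Unordered features additionally keep separate left/right blocks, which is what
// getLeftRightNodeFeatureOffsets exposes in the patch; that part is omitted here.
class FlatStatsSketch(numNodes: Int, numBinsPerFeature: Array[Int], statsSize: Int) {
  // Offset of each feature's block within one node's slice of allStats.
  private val featureOffsets: Array[Int] =
    numBinsPerFeature.scanLeft(0)((total, nBins) => total + nBins * statsSize).toArray
  private val nodeStride: Int = featureOffsets.last  // values stored per node
  val allStats = new Array[Double](numNodes * nodeStride)

  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int =
    nodeIndex * nodeStride + featureOffsets(featureIndex)

  // Classification-style update: bump the count of this label in this bin.
  // (Assumes label.toInt < statsSize; a regression aggregator would instead
  // add 1, label, and label * label into the three slots of the bin.)
  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
    allStats(i + label.toInt) += 1.0
  }
}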
/**
* Returns an array of optimal splits for a group of nodes at a given level
*
@@ -481,8 +588,9 @@ object DecisionTree extends Serializable with Logging {
* @param parentImpurities Impurities for all parent nodes for the current level
* @param metadata Learning and dataset metadata
* @param level Level of the tree
- * @param splits possible splits for all features
- * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
+ * @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
* @return array of splits with best splits for all nodes at a given level.
@@ -527,88 +635,22 @@ object DecisionTree extends Serializable with Logging {
// numNodes: Number of nodes in this (level of tree, group),
// where nodes at deeper (larger) levels may be divided into groups.
- val numNodes = (1 << level) / numGroups
+ val numNodes = Node.maxNodesInLevel(level) / numGroups
logDebug("numNodes = " + numNodes)
- // Find the number of features by looking at the first sample.
- val numFeatures = metadata.numFeatures
- logDebug("numFeatures = " + numFeatures)
-
- // numBins: Number of bins = 1 + number of possible splits
- val numBins = bins(0).length
- logDebug("numBins = " + numBins)
-
- val numClasses = metadata.numClasses
- logDebug("numClasses = " + numClasses)
-
- val isMulticlass = metadata.isMulticlass
- logDebug("isMulticlass = " + isMulticlass)
-
- val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
- logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
+ logDebug("numFeatures = " + metadata.numFeatures)
+ logDebug("numClasses = " + metadata.numClasses)
+ logDebug("isMulticlass = " + metadata.isMulticlass)
+ logDebug("isMulticlassWithCategoricalFeatures = " +
+ metadata.isMulticlassWithCategoricalFeatures)
// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
- /**
- * Get the node index corresponding to this data point.
- * This function mimics prediction, passing an example from the root node down to a node
- * at the current level being trained; that node's index is returned.
- *
- * @return Leaf index if the data point reaches a leaf.
- * Otherwise, last node reachable in tree matching this example.
- */
- def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
- if (node.isLeaf) {
- node.id
- } else {
- val featureIndex = node.split.get.feature
- val splitLeft = node.split.get.featureType match {
- case Continuous => {
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
- // We do not need to check lowSplit since bins are separated by splits.
- featureValueUpperBound <= node.split.get.threshold
- }
- case Categorical => {
- val featureValue = if (metadata.isUnordered(featureIndex)) {
- binnedFeatures(featureIndex)
- } else {
- val binIndex = binnedFeatures(featureIndex)
- bins(featureIndex)(binIndex).category
- }
- node.split.get.categories.contains(featureValue)
- }
- case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
- }
- if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
- // Return index from next layer of nodes to train
- if (splitLeft) {
- node.id * 2 + 1 // left
- } else {
- node.id * 2 + 2 // right
- }
- } else {
- if (splitLeft) {
- predictNodeIndex(node.leftNode.get, binnedFeatures)
- } else {
- predictNodeIndex(node.rightNode.get, binnedFeatures)
- }
- }
- }
- }
-
- def nodeIndexToLevel(idx: Int): Int = {
- if (idx == 0) {
- 0
- } else {
- math.floor(math.log(idx) / math.log(2)).toInt
- }
- }
-
- // Used for treePointToNodeIndex
- val levelOffset = (1 << level) - 1
+ // Used for treePointToNodeIndex to get an index for this (level, group).
+ // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level.
+ // - groupShift corrects for groups in this level before the current group.
+ val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
/**
* Find the node index for the given example.
@@ -619,661 +661,254 @@ object DecisionTree extends Serializable with Logging {
if (level == 0) {
0
} else {
- val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
- // Get index for this (level, group).
- globalNodeIndex - levelOffset - groupShift
- }
- }
-
- /**
- * Increment aggregate in location for (node, feature, bin, label).
- *
- * @param treePoint Data point being aggregated.
- * @param agg Array storing aggregate calculation, of size:
- * numClasses * numBins * numFeatures * numNodes.
- * Indexed by (node, feature, bin, label) where label is the least significant bit.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def updateBinForOrderedFeature(
- treePoint: TreePoint,
- agg: Array[Double],
- nodeIndex: Int,
- featureIndex: Int): Unit = {
- // Update the left or right count for one bin.
- val aggIndex =
- numClasses * numBins * numFeatures * nodeIndex +
- numClasses * numBins * featureIndex +
- numClasses * treePoint.binnedFeatures(featureIndex) +
- treePoint.label.toInt
- agg(aggIndex) += 1
- }
-
- /**
- * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label),
- * where [bins] ranges over all bins.
- * Updates left or right side of aggregate depending on split.
- *
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- * @param treePoint Data point being aggregated.
- * @param agg Indexed by (left/right, node, feature, bin, label)
- * where label is the least significant bit.
- * The left/right specifier is a 0/1 index indicating left/right child info.
- * @param rightChildShift Offset for right side of agg.
- */
- def updateBinForUnorderedFeature(
- nodeIndex: Int,
- featureIndex: Int,
- treePoint: TreePoint,
- agg: Array[Double],
- rightChildShift: Int): Unit = {
- val featureValue = treePoint.binnedFeatures(featureIndex)
- // Update the left or right count for one bin.
- val aggShift =
- numClasses * numBins * numFeatures * nodeIndex +
- numClasses * numBins * featureIndex +
- treePoint.label.toInt
- // Find all matching bins and increment their values
- val featureCategories = metadata.featureArity(featureIndex)
- val numCategoricalBins = (1 << featureCategories - 1) - 1
- var binIndex = 0
- while (binIndex < numCategoricalBins) {
- val aggIndex = aggShift + binIndex * numClasses
- if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
- agg(aggIndex) += 1
- } else {
- agg(rightChildShift + aggIndex) += 1
- }
- binIndex += 1
- }
- }
-
- /**
- * Helper for binSeqOp.
- *
- * @param agg Array storing aggregate calculation, of size:
- * numClasses * numBins * numFeatures * numNodes.
- * Indexed by (node, feature, bin, label) where label is the least significant bit.
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def binaryOrNotCategoricalBinSeqOp(
- agg: Array[Double],
- treePoint: TreePoint,
- nodeIndex: Int): Unit = {
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
- featureIndex += 1
- }
- }
-
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
-
- /**
- * Helper for binSeqOp.
- *
- * @param agg Array storing aggregate calculation.
- * For ordered features, this is of size:
- * numClasses * numBins * numFeatures * numNodes.
- * For unordered features, this is of size:
- * 2 * numClasses * numBins * numFeatures * numNodes.
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- */
- def multiclassWithCategoricalBinSeqOp(
- agg: Array[Double],
- treePoint: TreePoint,
- nodeIndex: Int): Unit = {
- val label = treePoint.label
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isUnordered(featureIndex)) {
- updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
- } else {
- updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
- }
- featureIndex += 1
- }
- }
-
- /**
- * Performs a sequential aggregation over a partition for regression.
- * For l nodes, k features,
- * the count, sum, sum of squares of one of the p bins is incremented.
- *
- * @param agg Array storing aggregate calculation, updated by this function.
- * Size: 3 * numBins * numFeatures * numNodes
- * @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- * @return agg
- */
- def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
- val label = treePoint.label
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- // Update count, sum, and sum^2 for one bin.
- val binIndex = treePoint.binnedFeatures(featureIndex)
- val aggIndex =
- 3 * numBins * numFeatures * nodeIndex +
- 3 * numBins * featureIndex +
- 3 * binIndex
- agg(aggIndex) += 1
- agg(aggIndex + 1) += label
- agg(aggIndex + 2) += label * label
- featureIndex += 1
+ val globalNodeIndex =
+ predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+ globalNodeIndex - globalNodeIndexOffset
}
}
/**
* Performs a sequential aggregation over a partition.
- * For l nodes, k features,
- * For classification:
- * Either the left count or the right count of one of the bins is
- * incremented based upon whether the feature is classified as 0 or 1.
- * For regression:
- * The count, sum, sum of squares of one of the bins is incremented.
*
- * @param agg Array storing aggregate calculation, updated by this function.
- * Size for classification:
- * numClasses * numBins * numFeatures * numNodes for ordered features, or
- * 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
- * Size for regression:
- * 3 * numBins * numFeatures * numNodes.
+ * Each data point contributes to one node. For each feature,
+ * the aggregate sufficient statistics are updated for the relevant bins.
+ *
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (node, feature, bin).
* @param treePoint Data point being aggregated.
* @return agg
*/
- def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+ def binSeqOp(
+ agg: DTStatsAggregator,
+ treePoint: TreePoint): DTStatsAggregator = {
val nodeIndex = treePointToNodeIndex(treePoint)
// If the example does not reach this level, then nodeIndex < 0.
// If the example reaches this level but is handled in a different group,
// then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
if (nodeIndex >= 0 && nodeIndex < numNodes) {
- if (metadata.isClassification) {
- if (isMulticlassWithCategoricalFeatures) {
- multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
- } else {
- binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
- }
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg, treePoint, nodeIndex)
} else {
- regressionBinSeqOp(agg, treePoint, nodeIndex)
+ mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
}
}
agg
}
- // Calculate bin aggregate length for classification or regression.
- val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
- logDebug("binAggregateLength = " + binAggregateLength)
-
- /**
- * Combines the aggregates from partitions.
- * @param agg1 Array containing aggregates from one or more partitions
- * @param agg2 Array containing aggregates from one or more partitions
- * @return Combined aggregate from agg1 and agg2
- */
- def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
- var index = 0
- val combinedAggregate = new Array[Double](binAggregateLength)
- while (index < binAggregateLength) {
- combinedAggregate(index) = agg1(index) + agg2(index)
- index += 1
- }
- combinedAggregate
- }
-
// Calculate bin aggregates.
timer.start("aggregation")
- val binAggregates = {
- input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+ val binAggregates: DTStatsAggregator = {
+ val initAgg = new DTStatsAggregator(metadata, numNodes)
+ input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
}
timer.stop("aggregation")
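
The aggregation above relies only on the seqOp/combOp contract: each partition folds its points into a local DTStatsAggregator, and the per-partition aggregators are then merged (pairwise, in a tree pattern, for treeAggregate). A self-contained sketch of the same pattern on a plain histogram, using RDD.aggregate, which has the same contract; the names and binning rule here are illustrative only:

import org.apache.spark.rdd.RDD

// Sketch: per-partition sequential aggregation + merge, the pattern behind
// input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp).
def histogram(points: RDD[Int], numBins: Int): Array[Long] = {
  points.aggregate(new Array[Long](numBins))(
    // seqOp: fold one point into the partition-local histogram (values assumed >= 0)
    (agg, x) => { agg(x % numBins) += 1L; agg },
    // combOp: merge two partition histograms
    (a, b) => { var i = 0; while (i < numBins) { a(i) += b(i); i += 1 }; a })
}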
- logDebug("binAggregates.length = " + binAggregates.length)
- /**
- * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
- * @param leftNodeAgg left node aggregates for this (feature, split)
- * @param rightNodeAgg right node aggregate for this (feature, split)
- * @param topImpurity impurity of the parent node
- * @return information gain and statistics for all splits
- */
- def calculateGainForSplit(
- leftNodeAgg: Array[Double],
- rightNodeAgg: Array[Double],
- topImpurity: Double): InformationGainStats = {
- if (metadata.isClassification) {
- val leftTotalCount = leftNodeAgg.sum
- val rightTotalCount = rightNodeAgg.sum
-
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val rootNodeCounts = new Array[Double](numClasses)
- var classIndex = 0
- while (classIndex < numClasses) {
- rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
- classIndex += 1
- }
- metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
- }
- }
-
- val totalCount = leftTotalCount + rightTotalCount
- if (totalCount == 0) {
- // Return arbitrary prediction.
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
- }
-
- // Sum of count for each label
- val leftrightNodeAgg: Array[Double] =
- leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
- leftCount + rightCount
- }
-
- def indexOfLargestArrayElement(array: Array[Double]): Int = {
- val result = array.foldLeft(-1, Double.MinValue, 0) {
- case ((maxIndex, maxValue, currentIndex), currentValue) =>
- if (currentValue > maxValue) {
- (currentIndex, currentValue, currentIndex + 1)
- } else {
- (maxIndex, maxValue, currentIndex + 1)
- }
- }
- if (result._1 < 0) {
- throw new RuntimeException("DecisionTree internal error:" +
- " calculateGainForSplit failed in indexOfLargestArrayElement")
- }
- result._1
- }
-
- val predict = indexOfLargestArrayElement(leftrightNodeAgg)
- val prob = leftrightNodeAgg(predict) / totalCount
-
- val leftImpurity = if (leftTotalCount == 0) {
- topImpurity
- } else {
- metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
- }
- val rightImpurity = if (rightTotalCount == 0) {
- topImpurity
- } else {
- metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
- }
-
- val leftWeight = leftTotalCount / totalCount
- val rightWeight = rightTotalCount / totalCount
-
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
-
- } else {
- // Regression
-
- val leftCount = leftNodeAgg(0)
- val leftSum = leftNodeAgg(1)
- val leftSumSquares = leftNodeAgg(2)
+ // Calculate best splits for all nodes at a given level
+ timer.start("chooseSplits")
+ val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ // Iterating over all nodes at this level
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
+ logDebug("node impurity = " + nodeImpurity)
+ bestSplits(nodeIndex) =
+ binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
+ logDebug("best split = " + bestSplits(nodeIndex)._1)
+ nodeIndex += 1
+ }
+ timer.stop("chooseSplits")
- val rightCount = rightNodeAgg(0)
- val rightSum = rightNodeAgg(1)
- val rightSumSquares = rightNodeAgg(2)
+ bestSplits
+ }
- val impurity = {
- if (level > 0) {
- topImpurity
- } else {
- // Calculate impurity for root node.
- val count = leftCount + rightCount
- val sum = leftSum + rightSum
- val sumSquares = leftSumSquares + rightSumSquares
- metadata.impurity.calculate(count, sum, sumSquares)
- }
- }
+ /**
+ * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregate for this (feature, split)
+ * @param topImpurity impurity of the parent node
+ * @return information gain and statistics for all splits
+ */
+ private def calculateGainForSplit(
+ leftImpurityCalculator: ImpurityCalculator,
+ rightImpurityCalculator: ImpurityCalculator,
+ topImpurity: Double,
+ level: Int,
+ metadata: DecisionTreeMetadata): InformationGainStats = {
- if (leftCount == 0) {
- return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
- rightSum / rightCount)
- }
- if (rightCount == 0) {
- return new InformationGainStats(0, topImpurity, topImpurity,
- Double.MinValue, leftSum / leftCount)
- }
+ val leftCount = leftImpurityCalculator.count
+ val rightCount = rightImpurityCalculator.count
- val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
- val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
+ val totalCount = leftCount + rightCount
+ if (totalCount == 0) {
+ // Return arbitrary prediction.
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ }
- val leftWeight = leftCount.toDouble / (leftCount + rightCount)
- val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+ val parentNodeAgg = leftImpurityCalculator.copy
+ parentNodeAgg.add(rightImpurityCalculator)
+ // impurity of parent node
+ val impurity = if (level > 0) {
+ topImpurity
+ } else {
+ parentNodeAgg.calculate()
+ }
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ val predict = parentNodeAgg.predict
+ val prob = parentNodeAgg.prob(predict)
- val predict = (leftSum + rightSum) / (leftCount + rightCount)
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
- }
- }
+ val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
+ val rightImpurity = rightImpurityCalculator.calculate()
- /**
- * Extracts left and right split aggregates.
- * @param binData Aggregate array slice from getBinDataForNode.
- * For classification:
- * For unordered features, this is leftChildData ++ rightChildData,
- * each of which is indexed by (feature, split/bin, class),
- * with class being the least significant bit.
- * For ordered features, this is of size numClasses * numBins * numFeatures.
- * For regression:
- * This is of size 2 * numFeatures * numBins.
- * @return (leftNodeAgg, rightNodeAgg) pair of arrays.
- * For classification, each array is of size (numFeatures, (numBins - 1), numClasses).
- * For regression, each array is of size (numFeatures, (numBins - 1), 3).
- *
- */
- def extractLeftRightNodeAggregates(
- binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
-
-
- /**
- * The input binData is indexed as (feature, bin, class).
- * This computes cumulative sums over splits.
- * Each (feature, class) pair is handled separately.
- * Note: numSplits = numBins - 1.
- * @param leftNodeAgg Each (feature, class) slice is an array over splits.
- * Element i (i = 0, ..., numSplits - 2) is set to be
- * the cumulative sum (from left) over binData for bins 0, ..., i.
- * @param rightNodeAgg Each (feature, class) slice is an array over splits.
- * Element i (i = 1, ..., numSplits - 1) is set to be
- * the cumulative sum (from right) over binData for bins
- * numBins - 1, ..., numBins - 1 - i.
- */
- def findAggForOrderedFeatureClassification(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- // shift for this featureIndex
- val shift = numClasses * featureIndex * numBins
-
- var classIndex = 0
- while (classIndex < numClasses) {
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
- = binData(shift + (numClasses * (numBins - 1)) + classIndex)
- classIndex += 1
- }
+ val leftWeight = leftCount / totalCount.toDouble
+ val rightWeight = rightCount / totalCount.toDouble
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- var innerClassIndex = 0
- while (innerClassIndex < numClasses) {
- leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
- = binData(shift + numClasses * splitIndex + innerClassIndex) +
- leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
- rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
- binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
- rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
- innerClassIndex += 1
- }
- splitIndex += 1
- }
- }
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- /**
- * Reshape binData for this feature.
- * Indexes binData as (feature, split, class) with class as the least significant bit.
- * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value
- */
- def findAggForUnorderedFeatureClassification(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- val rightChildShift = numClasses * numBins * numFeatures
- var splitIndex = 0
- while (splitIndex < numBins - 1) {
- var classIndex = 0
- while (classIndex < numClasses) {
- // shift for this featureIndex
- val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
- val leftBinValue = binData(shift + classIndex)
- val rightBinValue = binData(rightChildShift + shift + classIndex)
- leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
- rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
- classIndex += 1
- }
- splitIndex += 1
- }
- }
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+ }
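
A worked instance of the gain formula implemented above, gain = parentImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, with Gini impurity computed inline; the label counts are made up for illustration:

// Sketch: information gain for one candidate split, Gini impurity computed inline.
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}
val left   = Array(30.0, 10.0)  // hypothetical per-label counts, left child
val right  = Array(10.0, 50.0)  // right child
val parent = Array(left(0) + right(0), left(1) + right(1))
val (leftCount, rightCount) = (left.sum, right.sum)
val totalCount = leftCount + rightCount
val gain = gini(parent) -
  (leftCount / totalCount) * gini(left) -
  (rightCount / totalCount) * gini(right)
// gini(parent) = 0.48, gini(left) = 0.375, gini(right) ~= 0.278, so gain ~= 0.163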
- def findAggForRegression(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- featureIndex: Int) {
-
- // shift for this featureIndex
- val shift = 3 * featureIndex * numBins
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
- leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
- leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
-
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(numBins - 2)(0) =
- binData(shift + (3 * (numBins - 1)))
- rightNodeAgg(featureIndex)(numBins - 2)(1) =
- binData(shift + (3 * (numBins - 1)) + 1)
- rightNodeAgg(featureIndex)(numBins - 2)(2) =
- binData(shift + (3 * (numBins - 1)) + 2)
-
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- var i = 0 // index for regression histograms
- while (i < 3) { // count, sum, sum^2
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
- leftNodeAgg(featureIndex)(splitIndex - 1)(i)
- // calculating right node aggregate for a split as a sum of right node aggregate of a
- // higher split and the right bin aggregate of a bin where the split is a low split
- rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
- binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
- rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
- i += 1
- }
- splitIndex += 1
- }
- }
+ /**
+ * Find the best split for a node.
+ * @param binAggregates Bin statistics.
+ * @param nodeIndex Index for node to split in this (level, group).
+ * @param nodeImpurity Impurity of the node (nodeIndex).
+ * @return tuple for best split: (Split, information gain)
+ */
+ private def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ nodeIndex: Int,
+ nodeImpurity: Double,
+ level: Int,
+ metadata: DecisionTreeMetadata,
+ splits: Array[Array[Split]]): (Split, InformationGainStats) = {
- if (metadata.isClassification) {
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isUnordered(featureIndex)) {
- findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- } else {
- findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
- }
- featureIndex += 1
- }
- (leftNodeAgg, rightNodeAgg)
- } else {
- // Regression
- // Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
- featureIndex += 1
- }
- (leftNodeAgg, rightNodeAgg)
- }
- }
+ logDebug("node impurity = " + nodeImpurity)
- /**
- * Calculates information gain for all nodes splits.
- */
- def calculateGainsForAllNodeSplits(
- leftNodeAgg: Array[Array[Array[Double]]],
- rightNodeAgg: Array[Array[Array[Double]]],
- nodeImpurity: Double): Array[Array[InformationGainStats]] = {
- val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
-
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ Range(0, metadata.numFeatures).map { featureIndex =>
+ val numSplits = metadata.numSplits(featureIndex)
+ if (metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
var splitIndex = 0
- while (splitIndex < numSplitsForFeature) {
- gains(featureIndex)(splitIndex) =
- calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
- rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
- featureIndex += 1
- }
- gains
- }
-
- /**
- * Get the number of splits for a feature.
- */
- def getNumSplitsForFeature(featureIndex: Int): Int = {
- if (metadata.isContinuous(featureIndex)) {
- numBins - 1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIdx, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val (leftChildOffset, rightChildOffset) =
+ binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
- // Categorical feature
- val featureCategories = metadata.featureArity(featureIndex)
- if (metadata.isUnordered(featureIndex)) {
- (1 << featureCategories - 1) - 1
- } else {
- featureCategories
- }
- }
- }
-
- /**
- * Find the best split for a node.
- * @param binData Bin data slice for this node, given by getBinDataForNode.
- * @param nodeImpurity impurity of the top node
- * @return tuple of split and information gain
- */
- def binsToBestSplit(
- binData: Array[Double],
- nodeImpurity: Double): (Split, InformationGainStats) = {
-
- logDebug("node impurity = " + nodeImpurity)
-
- // Extract left right node aggregates.
- val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
-
- // Calculate gains for all splits.
- val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
-
- val (bestFeatureIndex, bestSplitIndex, gainStats) = {
- // Initialize with infeasible values.
- var bestFeatureIndex = Int.MinValue
- var bestSplitIndex = Int.MinValue
- var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
- // Iterate over features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- // Iterate over all splits.
- var splitIndex = 0
- val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
- while (splitIndex < numSplitsForFeature) {
- val gainStats = gains(featureIndex)(splitIndex)
- if (gainStats.gain > bestGainStats.gain) {
- bestGainStats = gainStats
- bestFeatureIndex = featureIndex
- bestSplitIndex = splitIndex
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+ val numBins = metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = if (metadata.isMulticlass) {
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ Range(0, numBins).map { case featureValue =>
+ val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.calculate()
+ } else {
+ Double.MaxValue
}
- splitIndex += 1
+ (featureValue, centroid)
+ }
+ } else { // regression or binary classification
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the centroid of their corresponding labels.
+ Range(0, numBins).map { case featureValue =>
+ val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ categoryStats.predict
+ } else {
+ Double.MaxValue
+ }
+ (featureValue, centroid)
}
- featureIndex += 1
}
- (bestFeatureIndex, bestSplitIndex, bestGainStats)
- }
- logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
- logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+ logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
- (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
- }
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
- /**
- * Get bin data for one node.
- */
- def getBinDataForNode(node: Int): Array[Double] = {
- if (metadata.isClassification) {
- if (isMulticlassWithCategoricalFeatures) {
- val shift = numClasses * node * numBins * numFeatures
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
- val binsForNode = {
- val leftChildData
- = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- val rightChildData
- = binAggregates.slice(rightChildShift + shift,
- rightChildShift + shift + numClasses * numBins * numFeatures)
- leftChildData ++ rightChildData
- }
- binsForNode
- } else {
- val shift = numClasses * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
- binsForNode
+ logDebug("Sorted centroids for categorical variable = " +
+ categoriesSortedByCentroid.mkString(","))
+
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
}
- } else {
- // Regression
- val shift = 3 * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
- binsForNode
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ val gainStats =
+ calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+ (bestFeatureSplit, bestFeatureGainStats)
}
- }
-
- // Calculate best splits for all nodes at a given level
- timer.start("chooseSplits")
- val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
- // Iterating over all nodes at this level
- var node = 0
- while (node < numNodes) {
- val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
- val binsForNode: Array[Double] = getBinDataForNode(node)
- logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
- val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
- logDebug("parent node impurity = " + parentNodeImpurity)
- bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
- node += 1
- }
- timer.stop("chooseSplits")
-
- bestSplits
+ }.maxBy(_._2.gain)
}
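
To make the centroid ordering for ordered categorical features concrete: the categories are sorted by their per-category centroid (predicted label, or impurity in the multiclass case), and only the K - 1 prefixes of that ordering are evaluated as splits. A small sketch with made-up centroids:

// Sketch: 3 categories, so only 2 candidate splits (prefixes of the centroid order).
val centroidForCategories = Seq(
  (0, 0.80),  // (category, centroid)
  (1, 0.10),
  (2, 0.45))
val categoriesSortedByCentroid = centroidForCategories.sortBy(_._2).map(_._1)
// categoriesSortedByCentroid == List(1, 2, 0)
val candidateLeftSubsets = (1 until categoriesSortedByCentroid.length).map { i =>
  categoriesSortedByCentroid.take(i)  // List(1), then List(1, 2)
}
candidateLeftSubsets.foreach(println)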
/**
* Get the number of values to be stored per node in the bin aggregates.
- *
- * @param numBins Number of bins = 1 + number of possible splits.
*/
- private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+ private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
+ val totalBins = metadata.numBins.sum
if (metadata.isClassification) {
- if (metadata.isMulticlassWithCategoricalFeatures) {
- 2 * metadata.numClasses * numBins * metadata.numFeatures
- } else {
- metadata.numClasses * numBins * metadata.numFeatures
- }
+ metadata.numClasses * totalBins
} else {
- 3 * numBins * metadata.numFeatures
+ 3 * totalBins
}
}
@@ -1284,6 +919,7 @@ object DecisionTree extends Serializable with Logging {
* Continuous features:
* For each feature, there are numBins - 1 possible splits representing the possible binary
* decisions at each node in the tree.
+ * This finds locations (feature values) for splits using a subsample of the data.
*
* Categorical features:
* For each feature, there is 1 bin per split.
@@ -1292,7 +928,6 @@ object DecisionTree extends Serializable with Logging {
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
- * There are (1 << maxFeatureValue - 1) - 1 splits.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
@@ -1302,7 +937,7 @@ object DecisionTree extends Serializable with Logging {
* @param metadata Learning and dataset metadata
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
- * of size (numFeatures, numBins - 1).
+ * of size (numFeatures, numSplits).
* Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
* of size (numFeatures, numBins).
*/
@@ -1310,84 +945,80 @@ object DecisionTree extends Serializable with Logging {
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
- val count = input.count()
+ logDebug("isMulticlass = " + metadata.isMulticlass)
- // Find the number of features by looking at the first sample
- val numFeatures = input.take(1)(0).features.size
-
- val maxBins = metadata.maxBins
- val numBins = if (maxBins <= count) maxBins else count.toInt
- logDebug("numBins = " + numBins)
- val isMulticlass = metadata.isMulticlass
- logDebug("isMulticlass = " + isMulticlass)
-
- /*
- * Ensure numBins is always greater than the categories. For multiclass classification,
- * numBins should be greater than 2^(maxCategories - 1) - 1.
- * It's a limitation of the current implementation but a reasonable trade-off since features
- * with large number of categories get favored over continuous features.
- *
- * This needs to be checked here instead of in Strategy since numBins can be determined
- * by the number of training examples.
- * TODO: Allow this case, where we simply will know nothing about some categories.
- */
- if (metadata.featureArity.size > 0) {
- val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
- require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
- "in categorical features")
- }
-
- // Calculate the number of sample for approximate quantile calculation.
- val requiredSamples = numBins*numBins
- val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
- logDebug("fraction of data used for calculating quantiles = " + fraction)
+ val numFeatures = metadata.numFeatures
- // sampled input for RDD calculation
- val sampledInput =
+ // Sample the input only if there are continuous features.
+ val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
+ val sampledInput = if (hasContinuousFeatures) {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ val fraction = if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
- val numSamples = sampledInput.length
-
- val stride: Double = numSamples.toDouble / numBins
- logDebug("stride = " + stride)
+ } else {
+ new Array[LabeledPoint](0)
+ }
metadata.quantileStrategy match {
case Sort =>
- val splits = Array.ofDim[Split](numFeatures, numBins - 1)
- val bins = Array.ofDim[Bin](numFeatures, numBins)
+ val splits = new Array[Array[Split]](numFeatures)
+ val bins = new Array[Array[Bin]](numFeatures)
// Find all splits.
-
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- // Check whether the feature is continuous.
- val isFeatureContinuous = metadata.isContinuous(featureIndex)
- if (isFeatureContinuous) {
+ val numSplits = metadata.numSplits(featureIndex)
+ val numBins = metadata.numBins(featureIndex)
+ if (metadata.isContinuous(featureIndex)) {
+ val numSamples = sampledInput.length
+ splits(featureIndex) = new Array[Split](numSplits)
+ bins(featureIndex) = new Array[Bin](numBins)
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
- val stride: Double = numSamples.toDouble / numBins
+ val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
logDebug("stride = " + stride)
- for (index <- 0 until numBins - 1) {
- val sampleIndex = index * stride.toInt
+ for (splitIndex <- 0 until numSplits) {
+ val sampleIndex = splitIndex * stride.toInt
// Set threshold halfway in between 2 samples.
val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
- val split = new Split(featureIndex, threshold, Continuous, List())
- splits(featureIndex)(index) = split
+ splits(featureIndex)(splitIndex) =
+ new Split(featureIndex, threshold, Continuous, List())
}
- } else { // Categorical feature
- val featureCategories = metadata.featureArity(featureIndex)
-
- // Use different bin/split calculation strategy for categorical features in multiclass
- // classification that satisfy the space constraint.
+ bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+ splits(featureIndex)(0), Continuous, Double.MinValue)
+ for (splitIndex <- 1 until numSplits) {
+ bins(featureIndex)(splitIndex) =
+ new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
+ Continuous, Double.MinValue)
+ }
+ bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
+ new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+ } else {
+ // Categorical feature
+ val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
- // 2^(maxFeatureValue- 1) - 1 combinations
- var index = 0
- while (index < (1 << featureCategories - 1) - 1) {
- val categories: List[Double]
- = extractMultiClassCategories(index + 1, featureCategories)
- splits(featureIndex)(index)
- = new Split(featureIndex, Double.MinValue, Categorical, categories)
- bins(featureIndex)(index) = {
- if (index == 0) {
+ // TODO: The second half of the bins are unused. Actually, we could just use
+ // splits and not build bins for unordered features. That should be part of
+ // a later PR since it will require changing other code (using splits instead
+ // of bins in a few places).
+ // Unordered features
+ // 2^(featureArity - 1) - 1 combinations
+ splits(featureIndex) = new Array[Split](numSplits)
+ bins(featureIndex) = new Array[Bin](numBins)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val categories: List[Double] =
+ extractMultiClassCategories(splitIndex + 1, featureArity)
+ splits(featureIndex)(splitIndex) =
+ new Split(featureIndex, Double.MinValue, Categorical, categories)
+ bins(featureIndex)(splitIndex) = {
+ if (splitIndex == 0) {
new Bin(
new DummyCategoricalSplit(featureIndex, Categorical),
splits(featureIndex)(0),
@@ -1395,96 +1026,24 @@ object DecisionTree extends Serializable with Logging {
Double.MinValue)
} else {
new Bin(
- splits(featureIndex)(index - 1),
- splits(featureIndex)(index),
+ splits(featureIndex)(splitIndex - 1),
+ splits(featureIndex)(splitIndex),
Categorical,
Double.MinValue)
}
}
- index += 1
- }
- } else { // ordered feature
- /* For a given categorical feature, use a subsample of the data
- * to choose how to arrange possible splits.
- * This examines each category and computes a centroid.
- * These centroids are later used to sort the possible splits.
- * centroidForCategories is a mapping: category (for the given feature) --> centroid
- */
- val centroidForCategories = {
- if (isMulticlass) {
- // For categorical variables in multiclass classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the impurity of their corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex), lp.label))
- .groupBy(_._1)
- .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
- .map(x => (x._1, x._2.values.toArray))
- .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
- } else { // regression or binary classification
- // For categorical variables in regression and binary classification,
- // each bin is a category. The bins are sorted and they
- // are ordered by calculating the centroid of their corresponding labels.
- sampledInput.map(lp => (lp.features(featureIndex), lp.label))
- .groupBy(_._1)
- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
- }
- }
-
- logDebug("centroid for categories = " + centroidForCategories.mkString(","))
-
- // Check for missing categorical variables and putting them last in the sorted list.
- val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
- for (i <- 0 until featureCategories) {
- if (centroidForCategories.contains(i)) {
- fullCentroidForCategories(i) = centroidForCategories(i)
- } else {
- fullCentroidForCategories(i) = Double.MaxValue
- }
- }
-
- // bins sorted by centroids
- val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
- logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
-
- var categoriesForSplit = List[Double]()
- categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
- case ((key, value), index) =>
- categoriesForSplit = key :: categoriesForSplit
- splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
- Categorical, categoriesForSplit)
- bins(featureIndex)(index) = {
- if (index == 0) {
- new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
- splits(featureIndex)(0), Categorical, key)
- } else {
- new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
- Categorical, key)
- }
- }
+ splitIndex += 1
}
+ } else {
+ // Ordered features
+ // Bins correspond to feature values, so we do not need to compute splits or bins
+ // beforehand. Splits are constructed as needed during training.
+ splits(featureIndex) = new Array[Split](0)
+ bins(featureIndex) = new Array[Bin](0)
}
}
featureIndex += 1
}
-
- // Find all bins.
- featureIndex = 0
- while (featureIndex < numFeatures) {
- val isFeatureContinuous = metadata.isContinuous(featureIndex)
- if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
- bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
- splits(featureIndex)(0), Continuous, Double.MinValue)
- for (index <- 1 until numBins - 1) {
- val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
- Continuous, Double.MinValue)
- bins(featureIndex)(index) = bin
- }
- bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
- new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
- }
- featureIndex += 1
- }
(splits, bins)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
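
The continuous-feature branch above picks split thresholds by striding through a sorted sample. A minimal standalone sketch of that selection (the helper name and sample values below are illustrative, not part of the patch):

    // Pick numSplits thresholds from a sorted sample; the stride mirrors the code above
    // (numSamples / numBins, with numBins = numSplits + 1), and each threshold sits
    // halfway between two neighboring sample values.
    def continuousThresholds(sortedSamples: Array[Double], numSplits: Int): Array[Double] = {
      val stride: Double = sortedSamples.length.toDouble / (numSplits + 1)
      Array.tabulate(numSplits) { splitIndex =>
        val sampleIndex = splitIndex * stride.toInt
        (sortedSamples(sampleIndex) + sortedSamples(sampleIndex + 1)) / 2.0
      }
    }

    // continuousThresholds(Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), numSplits = 2)
    //   returns Array(1.5, 3.5)
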
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index cfc8192a85abd..23f74d5360fe5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -50,7 +50,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- * 128 MB.
+ * 256 MB.
*/
@Experimental
class Strategy (
@@ -58,10 +58,10 @@ class Strategy (
val impurity: Impurity,
val maxDepth: Int,
val numClassesForClassification: Int = 2,
- val maxBins: Int = 100,
+ val maxBins: Int = 32,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val maxMemoryInMB: Int = 128) extends Serializable {
+ val maxMemoryInMB: Int = 256) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
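
For reference, a hedged usage sketch of the new defaults (assuming the leading constructor parameter is named algo, as it is referenced in the class body): leaving maxBins and maxMemoryInMB unset now yields 32 and 256 respectively.

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Equivalent to passing maxBins = 32 and maxMemoryInMB = 256 explicitly.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 2)
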
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
new file mode 100644
index 0000000000000..866d85a79bea1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.apache.spark.mllib.tree.impurity._
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ */
+private[tree] class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ val numNodes: Int) extends Serializable {
+
+ /**
+ * [[ImpurityAggregator]] instance specifying the impurity type.
+ */
+ val impurityAggregator: ImpurityAggregator = metadata.impurity match {
+ case Gini => new GiniAggregator(metadata.numClasses)
+ case Entropy => new EntropyAggregator(metadata.numClasses)
+ case Variance => new VarianceAggregator()
+ case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
+ }
+
+ /**
+ * Number of elements (Double values) used for the sufficient statistics of each bin.
+ */
+ val statsSize: Int = impurityAggregator.statsSize
+
+ val numFeatures: Int = metadata.numFeatures
+
+ /**
+ * Number of bins for each feature. This is indexed by the feature index.
+ */
+ val numBins: Array[Int] = metadata.numBins
+
+ /**
+ * Number of splits for the given feature.
+ */
+ def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
+
+ /**
+ * Indicator for each feature of whether that feature is an unordered feature.
+ * TODO: Is Array[Boolean] any faster?
+ */
+ def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
+
+ /**
+ * Offset for each feature for calculating indices into the [[allStats]] array.
+ */
+ private val featureOffsets: Array[Int] = {
+ def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
+ if (isUnordered(featureIndex)) {
+ total + 2 * numBins(featureIndex)
+ } else {
+ total + numBins(featureIndex)
+ }
+ }
+ Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+ }
+
+ /**
+ * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
+ */
+ private val nodeStride: Int = featureOffsets.last
+
+ /**
+ * Total number of elements stored in this aggregator.
+ */
+ val allStatsSize: Int = numNodes * nodeStride
+
+ /**
+ * Flat array of elements.
+ * Index for start of stats for a (node, feature, bin) is:
+ * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
+ * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
+ */
+ val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+ /**
+ * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ */
+ def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+ }
+
+ /**
+ * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ */
+ def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
+ val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label)
+ }
+
+ /**
+ * Pre-compute node offset for use with [[nodeUpdate]].
+ */
+ def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+ /**
+ * Faster version of [[update]].
+ * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ */
+ def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
+ val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label)
+ }
+
+ /**
+ * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * For ordered features only.
+ */
+ def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+ require(!isUnordered(featureIndex),
+ s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
+ s" for unordered feature $featureIndex.")
+ nodeIndex * nodeStride + featureOffsets(featureIndex)
+ }
+
+ /**
+ * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * For unordered features only.
+ */
+ def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+ require(isUnordered(featureIndex),
+ s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+ s" but was called for ordered feature $featureIndex.")
+ val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
+ (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+ }
+
+ /**
+ * Faster version of [[update]].
+ * Update the stats for a given (node, feature, bin), using the given label.
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ */
+ def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
+ impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
+ }
+
+ /**
+ * For a given (node, feature), merge the stats for two bins.
+ * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getNodeFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (node, feature, left/right child) offset from
+ * [[getLeftRightNodeFeatureOffsets]].
+ * @param binIndex The other bin is merged into this bin.
+ * @param otherBinIndex This bin is not modified.
+ */
+ def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
+ nodeFeatureOffset + otherBinIndex * statsSize)
+ }
+
+ /**
+ * Merge this aggregator with another and return this aggregator.
+ * This method modifies this aggregator in-place.
+ */
+ def merge(other: DTStatsAggregator): DTStatsAggregator = {
+ require(allStatsSize == other.allStatsSize,
+ s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+ + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ var i = 0
+ // TODO: Test BLAS.axpy
+ while (i < allStatsSize) {
+ allStats(i) += other.allStats(i)
+ i += 1
+ }
+ this
+ }
+
+}
+
+private[tree] object DTStatsAggregator extends Serializable {
+
+ /**
+ * Combines two aggregates (modifying the first) and returns the combination.
+ */
+ def binCombOp(
+ agg1: DTStatsAggregator,
+ agg2: DTStatsAggregator): DTStatsAggregator = {
+ agg1.merge(agg2)
+ }
+
+}
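
A self-contained sketch of the flat-array layout documented in DTStatsAggregator, with hypothetical sizes (not taken from the patch); unordered features reserve twice their bin count so left- and right-child stats can live side by side:

    val statsSize = 2                      // e.g. class counts for 2 classes
    val numBins = Array(3, 4)              // hypothetical bins per feature
    val isUnordered = Array(false, true)   // feature 1 treated as unordered
    val featureOffsets = numBins.indices
      .scanLeft(0)((total, f) => total + (if (isUnordered(f)) 2 * numBins(f) else numBins(f)))
      .map(_ * statsSize)
      .toArray                             // Array(0, 6, 22)
    val nodeStride = featureOffsets.last   // 22 doubles per node

    // index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
    def statsIndex(nodeIndex: Int, featureIndex: Int, binIndex: Int): Int =
      nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize

    // statsIndex(1, 1, 0) == 22 + 6 + 0 == 28
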
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index d9eda354dc986..e95add7558bcf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD
-
/**
* Learning and dataset metadata for DecisionTree.
*
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
+ * @param maxBins Maximum number of bins, over all features.
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins Number of bins for each feature.
*/
private[tree] class DecisionTreeMetadata(
val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
val maxBins: Int,
val featureArity: Map[Int, Int],
val unorderedFeatures: Set[Int],
+ val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {
@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+ /**
+ * Number of splits for the given feature.
+ * For unordered features, there are 2 bins per split.
+ * For ordered features, there is 1 more bin than split.
+ */
+ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+ numBins(featureIndex) >> 1
+ } else {
+ numBins(featureIndex) - 1
+ }
+
}
private[tree] object DecisionTreeMetadata {
+ /**
+ * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+ * This computes which categorical features will be ordered vs. unordered,
+ * as well as the number of splits and bins for each feature.
+ */
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
case Regression => 0
}
- val maxBins = math.min(strategy.maxBins, numExamples).toInt
- val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+ // We check the number of bins here against maxPossibleBins.
+ // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+ // based on the number of training examples.
+ if (strategy.categoricalFeaturesInfo.nonEmpty) {
+ val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ require(maxCategoriesPerFeature <= maxPossibleBins,
+ s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+ s"in categorical features (= $maxCategoriesPerFeature)")
+ }
val unorderedFeatures = new mutable.HashSet[Int]()
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
if (numClasses > 2) {
- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- if (k - 1 < log2MaxBinsp1) {
- // Note: The above check is equivalent to checking:
- // numUnorderedBins = (1 << k - 1) - 1 < maxBins
- unorderedFeatures.add(f)
+ // Multiclass classification
+ val maxCategoriesForUnorderedFeature =
+ ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // Decide if some categorical features should be treated as unordered features,
+ // which require 2 * ((1 << numCategories - 1) - 1) bins.
+ // We do this check with log values to prevent overflows in case numCategories is large.
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+ if (numCategories <= maxCategoriesForUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
+ numBins(featureIndex) = numUnorderedBins(numCategories)
} else {
- // TODO: Allow this case, where we simply will know nothing about some categories?
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ numBins(featureIndex) = numCategories
}
}
} else {
- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ // Binary classification or regression
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ numBins(featureIndex) = numCategories
}
}
- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+ new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy)
}
+ /**
+ * Given the arity of a categorical feature (arity = number of categories),
+ * return the number of bins for the feature if it is to be treated as an unordered feature.
+ * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+ * there are math.pow(2, arity - 1) - 1 such splits.
+ * Each split has 2 corresponding bins.
+ */
+ def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)
+
}
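
A worked example of the split/bin bookkeeping above, using hypothetical arities; it restates numUnorderedBins and the numSplits rule from DecisionTreeMetadata:

    // Unordered feature with arity k: (1 << (k - 1)) - 1 splits, 2 bins per split.
    def unorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1
    def unorderedBins(arity: Int): Int = 2 * unorderedSplits(arity)

    assert(unorderedSplits(3) == 3)   // {0}, {1}, {2} versus their complements
    assert(unorderedBins(3) == 6)
    assert(unorderedBins(4) == 14)

    // Ordered feature with k categories: k bins and k - 1 splits.
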
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 170e43e222083..35e361ae309cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -48,54 +48,63 @@ private[tree] object TreePoint {
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
* @param bins Bins for features, of size (numFeatures, numBins).
- * @param metadata Learning and dataset metadata
+ * @param metadata Learning and dataset metadata
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
+ // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+ val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
+ val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
+ var featureIndex = 0
+ while (featureIndex < metadata.numFeatures) {
+ featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
+ isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
+ featureIndex += 1
+ }
input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, metadata)
+ TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
}
}
/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
+ * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
+ * for categorical features.
+ * @param isUnordered Array indexed by feature, with value true for unordered categorical features.
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
- metadata: DecisionTreeMetadata): TreePoint = {
-
+ featureArity: Array[Int],
+ isUnordered: Array[Boolean]): TreePoint = {
val numFeatures = labeledPoint.features.size
- val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
- metadata.isUnordered(featureIndex), bins, metadata.featureArity)
+ arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
+ isUnordered(featureIndex), bins)
featureIndex += 1
}
-
new TreePoint(labeledPoint.label, arr)
}
/**
* Find bin for one (labeledPoint, feature).
*
+ * @param featureArity 0 for continuous features; number of categories for categorical features.
* @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
- * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
- isFeatureContinuous: Boolean,
+ featureArity: Int,
isUnorderedFeature: Boolean,
- bins: Array[Array[Bin]],
- categoricalFeaturesInfo: Map[Int, Int]): Int = {
+ bins: Array[Array[Bin]]): Int = {
/**
* Binary search helper method for continuous feature.
@@ -121,44 +130,7 @@ private[tree] object TreePoint {
-1
}
- /**
- * Sequential search helper method to find bin for categorical feature in multiclass
- * classification. The category is returned since each category can belong to multiple
- * splits. The actual left/right child allocation per split is performed in the
- * sequential phase of the bin aggregate operation.
- */
- def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
- labeledPoint.features(featureIndex).toInt
- }
-
- /**
- * Sequential search helper method to find bin for categorical feature
- * (for classification and regression).
- */
- def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
- val featureCategories = categoricalFeaturesInfo(featureIndex)
- val featureValue = labeledPoint.features(featureIndex)
- var binIndex = 0
- while (binIndex < featureCategories) {
- val bin = bins(featureIndex)(binIndex)
- val categories = bin.highSplit.categories
- if (categories.contains(featureValue)) {
- return binIndex
- }
- binIndex += 1
- }
- if (featureValue < 0 || featureValue >= featureCategories) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureCategories - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- -1
- }
-
- if (isFeatureContinuous) {
+ if (featureArity == 0) {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
if (binIndex == -1) {
@@ -168,18 +140,17 @@ private[tree] object TreePoint {
}
binIndex
} else {
- // Perform sequential search to find bin for categorical features.
- val binIndex = if (isUnorderedFeature) {
- sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
- } else {
- sequentialBinSearchForOrderedCategoricalFeature()
- }
- if (binIndex == -1) {
- throw new RuntimeException("No bin was found for categorical feature." +
- " This error can occur when given invalid data values (such as NaN)." +
- s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
+ // Categorical feature bins are indexed by feature values.
+ val featureValue = labeledPoint.features(featureIndex)
+ if (featureValue < 0 || featureValue >= featureArity) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureArity - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
}
- binIndex
+ featureValue.toInt
}
}
}
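
A minimal sketch of the simplified categorical branch of findBin above: the bin index is just the feature value, after validating that it lies in {0, ..., featureArity - 1} (the helper name is illustrative):

    def categoricalBin(featureValue: Double, featureArity: Int): Int = {
      require(featureValue >= 0 && featureValue < featureArity,
        s"categorical value $featureValue is outside {0,...,${featureArity - 1}}")
      featureValue.toInt
    }

    // categoricalBin(2.0, featureArity = 3) == 2
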
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 96d2471e1f88c..1c8afc2d0f4bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -74,3 +74,87 @@ object Entropy extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ */
+private[tree] class EntropyAggregator(numClasses: Int)
+ extends ImpurityAggregator(numClasses) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ if (label >= statsSize) {
+ throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+ s" but requires label < numClasses (= $statsSize).")
+ }
+ allStats(offset + label.toInt) += 1
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
+ new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Entropy.calculate(stats, stats.sum)
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats.sum.toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ indexOfLargestArrayElement(stats)
+ }
+
+ /**
+ * Probability of the label given by [[predict]].
+ */
+ override def prob(label: Double): Double = {
+ val lbl = label.toInt
+ require(lbl < stats.length,
+ s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ val cnt = count
+ if (cnt == 0) {
+ 0
+ } else {
+ stats(lbl) / cnt
+ }
+ }
+
+ override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])"
+
+}
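
A standalone illustration (plain arrays, not the private classes above) of how the per-bin class counts act as sufficient statistics for classification: the prediction is the majority class and prob is its relative frequency.

    val stats = Array(3.0, 5.0, 2.0)                // hypothetical counts for classes 0, 1, 2
    val count = stats.sum                           // 10.0
    val predicted = stats.indexOf(stats.max)        // 1
    val probOfPredicted = stats(predicted) / count  // 0.5
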
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index d586f449048bb..5cfdf345d163c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -70,3 +70,87 @@ object Gini extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ */
+private[tree] class GiniAggregator(numClasses: Int)
+ extends ImpurityAggregator(numClasses) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ if (label >= statsSize) {
+ throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+ s" but requires label < numClasses (= $statsSize).")
+ }
+ allStats(offset + label.toInt) += 1
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
+ new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[GiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: GiniCalculator = new GiniCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Gini.calculate(stats, stats.sum)
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats.sum.toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ indexOfLargestArrayElement(stats)
+ }
+
+ /**
+ * Probability of the label given by [[predict]].
+ */
+ override def prob(label: Double): Double = {
+ val lbl = label.toInt
+ require(lbl < stats.length,
+ s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ val cnt = count
+ if (cnt == 0) {
+ 0
+ } else {
+ stats(lbl) / cnt
+ }
+ }
+
+ override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])"
+
+}
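
The same class-count statistics feed the Gini computation; a worked example under the standard definition 1 - sum_i p_i^2, which is what the calculate(counts, totalCount) form computes:

    val counts = Array(3.0, 5.0, 2.0)
    val total = counts.sum
    val gini = 1.0 - counts.map(c => (c / total) * (c / total)).sum   // 1 - 0.38 = 0.62
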
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 92b0c7b4a6fbc..5a047d6cb5480 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
/**
* :: Experimental ::
* Trait for calculating information gain.
+ * This trait is used for
+ * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]]
+ * (b) calculating impurity values from sufficient statistics.
*/
@Experimental
trait Impurity extends Serializable {
@@ -47,3 +50,127 @@ trait Impurity extends Serializable {
@DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double
}
+
+/**
+ * Interface for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param statsSize Length of the vector of sufficient statistics for one bin.
+ */
+private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable {
+
+ /**
+ * Merge the stats from one bin into another.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for (node, feature, bin) which is modified by the merge.
+ * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified.
+ */
+ def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = {
+ var i = 0
+ while (i < statsSize) {
+ allStats(offset + i) += allStats(otherOffset + i)
+ i += 1
+ }
+ }
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) {
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: ImpurityCalculator
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double
+
+ /**
+ * Add the stats from another calculator into this one, modifying and returning this calculator.
+ */
+ def add(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.size == other.stats.size,
+ s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+ s" Sizes are ${stats.size} and ${other.stats.size}.")
+ var i = 0
+ while (i < other.stats.size) {
+ stats(i) += other.stats(i)
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Subtract the stats from another calculator from this one, modifying and returning this
+ * calculator.
+ */
+ def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+ require(stats.size == other.stats.size,
+ s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+ s" Sizes are ${stats.size} and ${other.stats.size}.")
+ var i = 0
+ while (i < other.stats.size) {
+ stats(i) -= other.stats(i)
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double
+
+ /**
+ * Probability of the label given by [[predict]], or -1 if no probability is available.
+ */
+ def prob(label: Double): Double = -1
+
+ /**
+ * Return the index of the largest array element.
+ * Fails if the array is empty.
+ */
+ protected def indexOfLargestArrayElement(array: Array[Double]): Int = {
+ val result = array.foldLeft((-1, Double.MinValue, 0)) {
+ case ((maxIndex, maxValue, currentIndex), currentValue) =>
+ if (currentValue > maxValue) {
+ (currentIndex, currentValue, currentIndex + 1)
+ } else {
+ (maxIndex, maxValue, currentIndex + 1)
+ }
+ }
+ if (result._1 < 0) {
+ throw new RuntimeException("ImpurityCalculator internal error:" +
+ " indexOfLargestArrayElement failed")
+ }
+ result._1
+ }
+
+}
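
A standalone sketch (plain arrays, not ImpurityCalculator) of the add/subtract pattern these calculators are designed for: accumulate left-child stats bin by bin and derive the right-child stats by subtracting from the parent totals.

    def addInto(acc: Array[Double], other: Array[Double]): Array[Double] = {
      var i = 0
      while (i < other.length) { acc(i) += other(i); i += 1 }
      acc
    }
    def subtractFrom(acc: Array[Double], other: Array[Double]): Array[Double] = {
      var i = 0
      while (i < other.length) { acc(i) -= other(i); i += 1 }
      acc
    }

    val parentStats = Array(8.0, 4.0)                        // hypothetical class counts
    val leftStats = addInto(Array(0.0, 0.0), Array(5.0, 1.0))
    val rightStats = subtractFrom(parentStats.clone(), leftStats)   // Array(3.0, 3.0)
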
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index f7d99a40eb380..e9ccecb1b8067 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -61,3 +61,75 @@ object Variance extends Impurity {
def instance = this
}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ */
+private[tree] class VarianceAggregator()
+ extends ImpurityAggregator(statsSize = 3) with Serializable {
+
+ /**
+ * Update stats for one (node, feature, bin) with the given label.
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ allStats(offset) += 1
+ allStats(offset + 1) += label
+ allStats(offset + 2) += label * label
+ }
+
+ /**
+ * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+ * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+ * @param offset Start index of stats for this (node, feature, bin).
+ */
+ def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
+ new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+ }
+
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[VarianceAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ */
+private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+
+ require(stats.size == 3,
+ s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
+ s" but was given array of length ${stats.size}.")
+
+ /**
+ * Make a deep copy of this [[ImpurityCalculator]].
+ */
+ def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+
+ /**
+ * Calculate the impurity from the stored sufficient statistics.
+ */
+ def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
+
+ /**
+ * Number of data points accounted for in the sufficient statistics.
+ */
+ def count: Long = stats(0).toLong
+
+ /**
+ * Prediction which should be made based on the sufficient statistics.
+ */
+ def predict: Double = if (count == 0) {
+ 0
+ } else {
+ stats(1) / count
+ }
+
+ override def toString: String = {
+ s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
+ }
+
+}
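
A worked example for the (count, sum, sumSquares) statistics above, assuming the usual biased-variance form sumSquares / count - (sum / count)^2; the mean doubles as the node prediction:

    val xs = Array(1.0, 2.0, 3.0, 6.0)
    val cnt = xs.length.toDouble             // 4.0
    val sum = xs.sum                         // 12.0
    val sumSquares = xs.map(x => x * x).sum  // 50.0
    val mean = sum / cnt                     // 3.0  (the prediction)
    val variance = sumSquares / cnt - mean * mean   // 12.5 - 9.0 = 3.5
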
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index af35d88f713e5..0cad473782af1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
- * Used for "binning" the features bins for faster best split calculation.
+ * Used for "binning" the feature values for faster best split calculation.
*
* For a continuous feature, the bin is determined by a low and a high split,
* where an example with featureValue falls into the bin s.t.
@@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* bins, splits, and feature values. The bin is determined by category/feature value.
* However, the bins are not necessarily ordered by feature value;
* they are ordered using impurity.
+ *
* For unordered categorical features, there is a 1-1 correspondence between bins, splits,
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ * An unordered feature with k categories uses (1 << (k - 1)) - 1 splits (and 2 bins per
+ * split), corresponding to all partitionings of categories into 2 disjoint, non-empty sets.
*
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
- * accepted in the bin
+ * accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin for ordered features
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 0eee6262781c1..5b8a4cbed2306 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector
/**
* :: DeveloperApi ::
- * Node in a decision tree
- * @param id integer node id
+ * Node in a decision tree.
+ *
+ * About node indexing:
+ * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children.
+ * Node index 0 is not used.
+ *
+ * @param id integer node id, from 1
* @param predict predicted value at the node
 * @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
@@ -51,17 +56,13 @@ class Node (
* @param nodes array of nodes
*/
def build(nodes: Array[Node]): Unit = {
-
- logDebug("building node " + id + " at level " +
- (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+ logDebug("building node " + id + " at level " + Node.indexToLevel(id))
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
if (!isLeaf) {
- val leftNodeIndex = id * 2 + 1
- val rightNodeIndex = id * 2 + 2
- leftNode = Some(nodes(leftNodeIndex))
- rightNode = Some(nodes(rightNodeIndex))
+ leftNode = Some(nodes(Node.leftChildIndex(id)))
+ rightNode = Some(nodes(Node.rightChildIndex(id)))
leftNode.get.build(nodes)
rightNode.get.build(nodes)
}
@@ -96,24 +97,20 @@ class Node (
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
*/
- private[tree] def numDescendants: Int = {
- if (isLeaf) {
- 0
- } else {
- 2 + leftNode.get.numDescendants + rightNode.get.numDescendants
- }
+ private[tree] def numDescendants: Int = if (isLeaf) {
+ 0
+ } else {
+ 2 + leftNode.get.numDescendants + rightNode.get.numDescendants
}
/**
* Get depth of tree from this node.
* E.g.: Depth 0 means this is a leaf node.
*/
- private[tree] def subtreeDepth: Int = {
- if (isLeaf) {
- 0
- } else {
- 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
- }
+ private[tree] def subtreeDepth: Int = if (isLeaf) {
+ 0
+ } else {
+ 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
}
/**
@@ -148,3 +145,49 @@ class Node (
}
}
+
+private[tree] object Node {
+
+ /**
+ * Return the index of the left child of this node.
+ */
+ def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
+
+ /**
+ * Return the index of the right child of this node.
+ */
+ def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
+
+ /**
+ * Get the parent index of the given node, or 0 if it is the root.
+ */
+ def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1
+
+ /**
+ * Return the level of a tree which the given node is in.
+ */
+ def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) {
+ throw new IllegalArgumentException(s"0 is not a valid node index.")
+ } else {
+ java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex))
+ }
+
+ /**
+ * Returns true if this is a left child.
+ * Note: Returns false for the root.
+ */
+ def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0
+
+ /**
+ * Return the maximum number of nodes which can be in the given level of the tree.
+ * @param level Level of tree (0 = root).
+ */
+ def maxNodesInLevel(level: Int): Int = 1 << level
+
+ /**
+ * Return the index of the first node in the given level.
+ * @param level Level of tree (0 = root).
+ */
+ def startIndexInLevel(level: Int): Int = 1 << level
+
+}
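
A worked example of the 1-based node index arithmetic defined in the Node companion object above (the helpers are re-declared locally so the arithmetic can be checked in isolation):

    def leftChildIndex(i: Int): Int = i << 1
    def rightChildIndex(i: Int): Int = (i << 1) + 1
    def parentIndex(i: Int): Int = i >> 1
    def indexToLevel(i: Int): Int =
      java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(i))

    assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)   // children of the root
    assert(parentIndex(5) == 2)
    assert(indexToLevel(1) == 0 && indexToLevel(4) == 2 && indexToLevel(7) == 2)
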
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2f36fd907772c..69482f2acbb40 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,15 +21,15 @@ import scala.collection.JavaConverters._
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.mllib.regression.LabeledPoint
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -59,12 +59,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}
- test("split and bin calculation") {
+ test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(bins.length === 2)
@@ -72,7 +73,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(0).length === 100)
}
- test("split and bin calculation for categorical variables") {
+ test("Binary classification with binary (ordered) categorical features:" +
+ " split and bin calculation") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -83,77 +85,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
assert(splits.length === 2)
assert(bins.length === 2)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
-
- // Check splits.
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(0)(1).categories.contains(0.0))
-
- assert(splits(0)(2) === null)
-
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 2)
- assert(splits(1)(1).categories.contains(1.0))
- assert(splits(1)(1).categories.contains(0.0))
-
- assert(splits(1)(2) === null)
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
-
- assert(bins(0)(1).category === 0.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(0.0))
-
- assert(bins(0)(2) === null)
-
- assert(bins(1)(0).category === 0.0)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(1)(1).category === 1.0)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 2)
- assert(bins(1)(1).highSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(1)(2) === null)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
- test("split and bin calculations for categorical variables with no sample for one category") {
+ test("Binary classification with 3-ary (ordered) categorical features," +
+ " with no samples for one category") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -164,104 +109,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
- // Check splits.
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(1.0))
- assert(splits(0)(1).categories.contains(0.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 3)
- assert(splits(0)(2).categories.contains(1.0))
- assert(splits(0)(2).categories.contains(0.0))
- assert(splits(0)(2).categories.contains(2.0))
-
- assert(splits(0)(3) === null)
-
- assert(splits(1)(0).feature === 1)
- assert(splits(1)(0).threshold === Double.MinValue)
- assert(splits(1)(0).featureType === Categorical)
- assert(splits(1)(0).categories.length === 1)
- assert(splits(1)(0).categories.contains(0.0))
-
- assert(splits(1)(1).feature === 1)
- assert(splits(1)(1).threshold === Double.MinValue)
- assert(splits(1)(1).featureType === Categorical)
- assert(splits(1)(1).categories.length === 2)
- assert(splits(1)(1).categories.contains(1.0))
- assert(splits(1)(1).categories.contains(0.0))
-
- assert(splits(1)(2).feature === 1)
- assert(splits(1)(2).threshold === Double.MinValue)
- assert(splits(1)(2).featureType === Categorical)
- assert(splits(1)(2).categories.length === 3)
- assert(splits(1)(2).categories.contains(1.0))
- assert(splits(1)(2).categories.contains(0.0))
- assert(splits(1)(2).categories.contains(2.0))
-
- assert(splits(1)(3) === null)
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
-
- assert(bins(0)(1).category === 0.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(0.0))
-
- assert(bins(0)(2).category === 2.0)
- assert(bins(0)(2).lowSplit.categories.length === 2)
- assert(bins(0)(2).lowSplit.categories.contains(1.0))
- assert(bins(0)(2).lowSplit.categories.contains(0.0))
- assert(bins(0)(2).highSplit.categories.length === 3)
- assert(bins(0)(2).highSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.contains(0.0))
- assert(bins(0)(2).highSplit.categories.contains(2.0))
-
- assert(bins(0)(3) === null)
-
- assert(bins(1)(0).category === 0.0)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(1)(1).category === 1.0)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 2)
- assert(bins(1)(1).highSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(1)(2).category === 2.0)
- assert(bins(1)(2).lowSplit.categories.length === 2)
- assert(bins(1)(2).lowSplit.categories.contains(0.0))
- assert(bins(1)(2).lowSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.length === 3)
- assert(bins(1)(2).highSplit.categories.contains(0.0))
- assert(bins(1)(2).highSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.contains(2.0))
-
- assert(bins(1)(3) === null)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
test("extract categories from a number for multiclass classification") {
@@ -270,8 +127,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}
- test("split and bin calculations for unordered categorical variables with multiclass " +
- "classification") {
+ test("Multiclass classification with unordered categorical features:" +
+ " split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -282,8 +139,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 3)
+ assert(bins(0).length === 6)
    // Expecting 2^(3-1) - 1 = 3 splits per unordered categorical feature (and 2 * 3 = 6 bins)
assert(splits(0)(0).feature === 0)
@@ -321,10 +185,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))
- assert(splits(0)(3) === null)
- assert(splits(1)(3) === null)
-
-
// Check bins.
assert(bins(0)(0).category === Double.MinValue)
@@ -360,13 +220,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(0.0))
- assert(bins(0)(3) === null)
- assert(bins(1)(3) === null)
-
}
- test("split and bin calculations for ordered categorical variables with multiclass " +
- "classification") {
+ test("Multiclass classification with ordered categorical features: split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
@@ -377,52 +233,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
+ // 2^10 - 1 > 100, so categorical features will be ordered
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-
- // 2^10 - 1 > 100, so categorical variables will be ordered
-
- assert(splits(0)(0).feature === 0)
- assert(splits(0)(0).threshold === Double.MinValue)
- assert(splits(0)(0).featureType === Categorical)
- assert(splits(0)(0).categories.length === 1)
- assert(splits(0)(0).categories.contains(1.0))
-
- assert(splits(0)(1).feature === 0)
- assert(splits(0)(1).threshold === Double.MinValue)
- assert(splits(0)(1).featureType === Categorical)
- assert(splits(0)(1).categories.length === 2)
- assert(splits(0)(1).categories.contains(2.0))
-
- assert(splits(0)(2).feature === 0)
- assert(splits(0)(2).threshold === Double.MinValue)
- assert(splits(0)(2).featureType === Categorical)
- assert(splits(0)(2).categories.length === 3)
- assert(splits(0)(2).categories.contains(2.0))
- assert(splits(0)(2).categories.contains(1.0))
-
- assert(splits(0)(10) === null)
- assert(splits(1)(10) === null)
-
-
- // Check bins.
-
- assert(bins(0)(0).category === 1.0)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(1.0))
- assert(bins(0)(1).category === 2.0)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).highSplit.categories.length === 2)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(0)(1).highSplit.categories.contains(2.0))
-
- assert(bins(0)(10) === null)
-
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
}
- test("classification stump with all categorical variables") {
+ test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -433,15 +258,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ // no bins or splits pre-computed for ordered categorical features
+ assert(splits(0).length === 0)
+ assert(bins(0).length === 0)
+
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
- assert(split.categories.length === 1)
- assert(split.categories.contains(1.0))
+ assert(split.categories === List(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
@@ -452,7 +285,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
- test("regression stump with all categorical variables") {
+ test("Regression stump with 3-ary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -462,10 +295,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
val split = bestSplits(0)._1
@@ -480,7 +317,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
- test("regression stump with categorical variables of arity 2") {
+ test("Regression stump with binary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -490,6 +327,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
validateRegressor(model, arr, 0.0)
@@ -497,22 +337,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(model.depth === 1)
}
- test("stump with fixed label 0 for Gini") {
+ test("Binary classification stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+ val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -521,22 +363,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.rightImpurity === 0)
}
- test("stump with fixed label 1 for Gini") {
+ test("Binary classification stump with fixed label 1 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 2, 100)
+ val strategy = new Strategy(Classification, Gini, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -546,22 +390,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}
- test("stump with fixed label 0 for Entropy") {
+ test("Binary classification stump with fixed label 0 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -571,22 +417,24 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 0)
}
- test("stump with fixed label 1 for Entropy") {
+ test("Binary classification stump with fixed label 1 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
+ numClassesForClassification = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +444,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}
- test("second level node building with/without groups") {
+ test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
@@ -607,18 +455,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
// Train a 1-node model
val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val nodes: Array[Node] = new Array[Node](7)
- nodes(0) = modelOneNode.topNode
- nodes(0).leftNode = None
- nodes(0).rightNode = None
+ val nodes: Array[Node] = new Array[Node](8)
+ nodes(1) = modelOneNode.topNode
+ nodes(1).leftNode = None
+ nodes(1).rightNode = None
- val parentImpurities = Array(0.5, 0.5, 0.5)
+ val parentImpurities = Array(0, 0.5, 0.5, 0.5)
// Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
@@ -648,16 +494,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
}
}
- test("stump with categorical variables for multiclass classification") {
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
+
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -668,7 +517,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}
- test("stump with 1 continuous variable for binary classification, to check off-by-1 error") {
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
@@ -684,26 +533,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(model.depth === 1)
}
- test("stump with 2 continuous variables for binary classification") {
+ test("Binary classification stump with 2 continuous features") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
assert(model.topNode.split.get.feature === 1)
}
- test("stump with categorical variables for multiclass classification, with just enough bins") {
- val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -711,6 +561,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
+ assert(metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 1.0)
@@ -719,7 +571,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -733,11 +585,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(gain.rightImpurity === 0)
}
- test("stump with continuous variables for multiclass classification") {
+ test("Multiclass classification stump with continuous features") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3)
+ numClassesForClassification = 3, maxBins = 100)
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -746,7 +598,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -759,20 +611,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
}
- test("stump with continuous + categorical variables for multiclass classification") {
+ test("Multiclass classification stump with continuous + unordered categorical features") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+ numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
validateClassifier(model, arr, 0.9)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -784,17 +637,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.threshold < 2020)
}
- test("stump with categorical variables for ordered multiclass classification") {
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ numClassesForClassification = 3, maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ assert(!metadata.isUnordered(featureIndex = 0))
+ assert(!metadata.isUnordered(featureIndex = 1))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0,
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
new Array[Node](0), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -805,6 +661,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+ numClassesForClassification = 3, maxBins = 10,
+ categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ assert(strategy.isMulticlassClassification)
+
+ val model = DecisionTree.train(rdd, strategy)
+ validateClassifier(model, arr, 0.6)
+ }
}
@@ -899,5 +767,4 @@ object DecisionTreeSuite {
arr
}
-
}
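Editorial note on the DecisionTreeSuite changes above: the repeated bump from new Array(7) to new Array(8), together with nodes(1) = modelOneNode.topNode and the padded parentImpurities = Array(0, 0.5, 0.5, 0.5), is consistent with a switch to 1-based node indexing (slot 0 unused, children of node i at 2i and 2i+1). The suite's own expressions also tie the ordered/unordered decision for categorical features to bin counts (maxBins = 2 * (2^(3-1) - 1) is "just enough"). A minimal Python sketch of that arithmetic, offered as an illustration and not as the library's API:

    def children(i):
        # Children of node i under 1-based binary-tree (heap) indexing.
        return 2 * i, 2 * i + 1

    def node_array_size(max_level):
        # Levels 0..max_level hold 2**(max_level + 1) - 1 nodes; leaving
        # index 0 unused needs one extra slot, i.e. 2**(max_level + 1).
        return 2 ** (max_level + 1)

    def unordered_bins(num_categories):
        # Bins needed to enumerate all 2**(k-1) - 1 category subsets,
        # with a low and a high bin per candidate split.
        return 2 * (2 ** (num_categories - 1) - 1)

    assert children(1) == (2, 3)        # the root lives at index 1
    assert node_array_size(2) == 8      # the Array(8) size used for the depth-2 stumps
    assert unordered_bins(3) == 6       # the "just enough bins" value in the suite
    assert unordered_bins(10) > 100     # why 10-ary features stay ordered at maxBins = 100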
diff --git a/pom.xml b/pom.xml
index d4650aff7b364..3134e47394590 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent</artifactId>
-   <version>1.1.0-SNAPSHOT</version>
+   <version>1.2.0-SNAPSHOT</version>
    <packaging>pom</packaging>
    <name>Spark Project Parent POM</name>
    <url>http://spark.apache.org/</url>
@@ -220,6 +220,18 @@
          <enabled>false</enabled>
+
+    <repository>
+      <id>spark-staging-1030</id>
+      <name>Spark 1.1.0 Staging (1030)</name>
+      <url>https://repository.apache.org/content/repositories/orgapachespark-1030/</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
@@ -875,7 +887,7 @@
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest-maven-plugin</artifactId>
-       <version>1.0-RC2</version>
+       <version>1.0</version>
        <configuration>
          <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
          <junitxml>.</junitxml>
@@ -886,6 +898,7 @@
            <java.awt.headless>true</java.awt.headless>
            <spark.test.home>${session.executionRootDirectory}</spark.test.home>
            <spark.testing>1</spark.testing>
+           <spark.ui.port>0</spark.ui.port>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 034ba6a7bf50f..0f5d71afcf616 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -85,7 +85,7 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.0.0"
+ val previousSparkVersion = "1.1.0"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 855d5cc8cf3fd..46b78bd5c7061 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -33,6 +33,18 @@ import com.typesafe.tools.mima.core._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.2") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("graphx")
+ ) ++
+ // This is @DeveloperAPI, but Mima still gives false-positives:
+ MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++
+ Seq(
+ // This is @Experimental, but Mima still gives false-positives:
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.foreachAsync")
+ )
case v if v.startsWith("1.1") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index a26c2c90cb321..45f6d2973ea90 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -184,7 +184,7 @@ object OldDeps {
def versionArtifact(id: String): Option[sbt.ModuleID] = {
val fullId = id + "_2.10"
- Some("org.apache.spark" % fullId % "1.0.0")
+ Some("org.apache.spark" % fullId % "1.1.0")
}
def oldDepsSettings() = Defaults.defaultSettings ++ Seq(
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 68062483dedaa..80e51d1a583a0 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -657,7 +657,6 @@ def save_partial(self, obj):
def save_file(self, obj):
"""Save a file"""
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
- from ..transport.adapter import SerializingAdapter
if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
@@ -691,13 +690,10 @@ def save_file(self, obj):
tmpfile.close()
if tst != '':
raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name)
- elif fsize > SerializingAdapter.max_transmit_data:
- raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" %
- (name,SerializingAdapter.max_transmit_data))
else:
try:
tmpfile = file(name)
- contents = tmpfile.read(SerializingAdapter.max_transmit_data)
+ contents = tmpfile.read()
tmpfile.close()
except IOError:
raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5a30431568b16..84bc0a3b7ccd0 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -232,6 +232,20 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __enter__(self):
+ """
+ Enable 'with SparkContext(...) as sc: app(sc)' syntax.
+ """
+ return self
+
+ def __exit__(self, type, value, trace):
+ """
+ Enable 'with SparkContext(...) as sc: app' syntax.
+
+ Specifically stop the context on exit of the with block.
+ """
+ self.stop()
+
@classmethod
def setSystemProperty(cls, key, value):
"""
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 3e59c73db85e3..d53c95fd59c25 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -28,7 +28,7 @@
__all__ = ['RandomRDDs', ]
-class RandomRDDs:
+class RandomRDDs(object):
"""
Generator methods for creating RDDs comprised of i.i.d samples from
some distribution.
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index a2fade61e9a71..ccc000ac70ba6 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -138,7 +138,7 @@ class DecisionTree(object):
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=4, maxBins=100):
+ impurity="gini", maxDepth=5, maxBins=32):
"""
Train a DecisionTreeModel for classification.
@@ -170,7 +170,7 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
@staticmethod
def trainRegressor(data, categoricalFeaturesInfo,
- impurity="variance", maxDepth=4, maxBins=100):
+ impurity="variance", maxDepth=5, maxBins=32):
"""
Train a DecisionTreeModel for regression.
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4962d05491c03..1c7b8c809ab5b 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -25,7 +25,7 @@
from pyspark.serializers import NoOpSerializer
-class MLUtils:
+class MLUtils(object):
"""
Helper methods to load, save and pre-process data used in MLlib.
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index dff6fc26fcb18..5667154cb84a8 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -212,11 +212,16 @@ def cache(self):
self.persist(StorageLevel.MEMORY_ONLY_SER)
return self
- def persist(self, storageLevel):
+ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
"""
Set this RDD's storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+
+ >>> rdd = sc.parallelize(["b", "a", "c"])
+ >>> rdd.persist().is_cached
+ True
"""
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
@@ -515,6 +520,30 @@ def __add__(self, other):
raise TypeError
return self.union(other)
+ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
+ ascending=True, keyfunc=lambda x: x):
+ """
+ Repartition the RDD according to the given partitioner and, within each resulting partition,
+ sort records by their keys.
+
+ >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
+ >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2)
+ >>> rdd2.glom().collect()
+ [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
+ """
+ if numPartitions is None:
+ numPartitions = self._defaultReducePartitions()
+
+ spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true")
+ memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+ serializer = self._jrdd_deserializer
+
+ def sortPartition(iterator):
+ sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+ return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
+
+ return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
+
def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
"""
Sorts this RDD, which is assumed to consist of (key, value) pairs.
@@ -1089,11 +1118,11 @@ def take(self, num):
# we actually cap it at totalParts in runJob.
numPartsToTry = 1
if partsScanned > 0:
- # If we didn't find any rows after the first iteration, just
- # try all partitions next. Otherwise, interpolate the number
- # of partitions we need to try, but overestimate it by 50%.
+ # If we didn't find any rows after the previous iteration,
+ # quadruple and retry. Otherwise, interpolate the number of
+ # partitions we need to try, but overestimate it by 50%.
if len(items) == 0:
- numPartsToTry = totalParts - 1
+ numPartsToTry = partsScanned * 4
else:
numPartsToTry = int(1.5 * num * partsScanned / len(items))
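The reworded comment above changes the growth policy for take(): an empty scan no longer jumps straight to all partitions but quadruples the number of partitions to try, while a partial result still interpolates and overestimates by 50%. A small worked example of both branches (function and variable names are illustrative; the real code additionally caps the value at the total partition count):

    def next_parts_to_try(parts_scanned, items_found, num):
        if items_found == 0:
            return parts_scanned * 4                         # quadruple and retry
        return int(1.5 * num * parts_scanned / items_found)  # interpolate, +50% slack

    assert next_parts_to_try(parts_scanned=1, items_found=0, num=100) == 4
    assert next_parts_to_try(parts_scanned=4, items_found=20, num=100) == 30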
@@ -2070,6 +2099,7 @@ def pipeline_func(split, iterator):
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
@@ -2100,6 +2130,11 @@ def _jrdd(self):
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
+ def id(self):
+ if self._id is None:
+ self._id = self._jrdd.id()
+ return self._id
+
def _is_pipelinable(self):
return not (self.is_cached or self.is_checkpointed)
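As a usage note for the new repartitionAndSortWithinPartitions: it is the usual building block for secondary-sort style processing, routing records with a partition function and relying on the per-partition sort for ordering. A hedged sketch, assuming a live pyspark shell where sc is defined (data and partition function are illustrative):

    pairs = sc.parallelize([("b", 2), ("a", 2), ("a", 1), ("b", 1)])

    # Route by the key's hash, sort each partition by key; a downstream
    # mapPartitions pass can then consume each key's records contiguously.
    arranged = pairs.repartitionAndSortWithinPartitions(
        numPartitions=2, partitionFunc=lambda k: hash(k) % 2)
    print(arranged.glom().collect())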
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index fde3c29e5e790..89cf76920e353 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -49,9 +49,9 @@
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
- /__ / .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT
+ /__ / .__/\_,_/_/ /_/\_\ version %s
/_/
-""")
+""" % sc.version)
print("Using Python version %s (%s, %s)" % (
platform.python_version(),
platform.python_build()[0],
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e7f573cf6da44..53eea6d6cf3ba 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -29,6 +29,7 @@
from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+from pyspark.storagelevel import StorageLevel
from itertools import chain, ifilter, imap
@@ -898,7 +899,7 @@ def __reduce__(self):
return Row
-class SQLContext:
+class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -1524,7 +1525,7 @@ def __init__(self, jschema_rdd, sql_ctx):
self.sql_ctx = sql_ctx
self._sc = sql_ctx._sc
self._jschema_rdd = jschema_rdd
-
+ self._id = None
self.is_cached = False
self.is_checkpointed = False
self.ctx = self.sql_ctx._sc
@@ -1542,9 +1543,10 @@ def _jrdd(self):
self._lazy_jrdd = self._jschema_rdd.javaToPython()
return self._lazy_jrdd
- @property
- def _id(self):
- return self._jrdd.id()
+ def id(self):
+ if self._id is None:
+ self._id = self._jrdd.id()
+ return self._id
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
@@ -1665,7 +1667,7 @@ def cache(self):
self._jschema_rdd.cache()
return self
- def persist(self, storageLevel):
+ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
self.is_cached = True
javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
self._jschema_rdd.persist(javaStorageLevel)
diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
index 2aa0fb9d2c1ed..676aa0f7144aa 100644
--- a/python/pyspark/storagelevel.py
+++ b/python/pyspark/storagelevel.py
@@ -18,7 +18,7 @@
__all__ = ["StorageLevel"]
-class StorageLevel:
+class StorageLevel(object):
"""
Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
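Several of the Python changes in this patch follow one pattern: class Foo: becomes class Foo(object), turning Python 2 old-style classes into new-style ones. That matters because descriptors such as property are only enforced on new-style classes; a minimal Python 2 illustration, not taken from the patch:

    class OldStyle:                      # old-style class (Python 2 only)
        @property
        def version(self):
            return "1.2.0"

    class NewStyle(object):              # new-style class, as in the patch
        @property
        def version(self):
            return "1.2.0"

    old = OldStyle()
    old.version = "0.0.0"                # silently shadows the property
    print(old.version)                   # prints 0.0.0

    new = NewStyle()
    try:
        new.version = "0.0.0"            # read-only property is enforced
    except AttributeError:
        print(new.version)               # prints 1.2.0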
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3e74799e82845..bb84ebe72cb24 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -169,6 +169,17 @@ def test_namedtuple(self):
self.assertEquals(p1, p2)
+# Regression test for SPARK-3415
+class CloudPickleTest(unittest.TestCase):
+ def test_pickling_file_handles(self):
+ from pyspark.cloudpickle import dumps
+ from StringIO import StringIO
+ from pickle import load
+ out1 = sys.stderr
+ out2 = load(StringIO(dumps(out1)))
+ self.assertEquals(out1, out2)
+
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):
@@ -281,6 +292,15 @@ def func():
class TestRDDFunctions(PySparkTestCase):
+ def test_id(self):
+ rdd = self.sc.parallelize(range(10))
+ id = rdd.id()
+ self.assertEqual(id, rdd.id())
+ rdd2 = rdd.map(str).filter(bool)
+ id2 = rdd2.id()
+ self.assertEqual(id + 1, id2)
+ self.assertEqual(id2, rdd2.id())
+
def test_failed_sparkcontext_creation(self):
# Regression test for SPARK-1550
self.sc.stop()
@@ -525,6 +545,14 @@ def test_histogram(self):
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))
+ def test_repartitionAndSortWithinPartitions(self):
+ rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+ repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
+ partitions = repartitioned.glom().collect()
+ self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
+ self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
+
class TestSQL(PySparkTestCase):
@@ -1226,6 +1254,35 @@ def test_single_script_on_cluster(self):
self.assertIn("[2, 4, 6]", out)
+class ContextStopTests(unittest.TestCase):
+
+ def test_stop(self):
+ sc = SparkContext()
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ sc.stop()
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with(self):
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with_exception(self):
+ try:
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ raise Exception()
+ except:
+ pass
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with_stop(self):
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ sc.stop()
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/run-tests b/python/run-tests
index f2a80b4f1838b..d98840de59d2c 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -19,7 +19,7 @@
# Figure out where the Spark framework is installed
-FWDIR="$(cd `dirname $0`; cd ../; pwd)"
+FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
# CD into the python directory to find things on the right path
cd "$FWDIR/python"
@@ -33,7 +33,9 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1"
- SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+
+ SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+
FAILED=$((PIPESTATUS[0]||$FAILED))
# Fail and exit on the first test failure.
@@ -48,6 +50,8 @@ function run_test() {
echo "Running PySpark tests. Output is in python/unit-tests.log."
+export PYSPARK_PYTHON="python"
+
# Try to test with Python 2.6, since that's the minimum version that we support:
if [ $(which python2.6) ]; then
export PYSPARK_PYTHON="python2.6"
diff --git a/repl/pom.xml b/repl/pom.xml
index 68f4504450778..fcc5f90d870e8 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent</artifactId>
-   <version>1.1.0-SNAPSHOT</version>
+   <version>1.2.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 910b31d209e13..7667a9c11979e 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -14,6 +14,8 @@ import scala.reflect.internal.util.Position
import scala.util.control.Exception.ignoring
import scala.tools.nsc.util.stackTraceString
+import org.apache.spark.SPARK_VERSION
+
/**
* Machinery for the asynchronous initialization of the repl.
*/
@@ -26,9 +28,9 @@ trait SparkILoopInit {
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
- /___/ .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT
+ /___/ .__/\_,_/_/ /_/\_\ version %s
/_/
-""")
+""".format(SPARK_VERSION))
import Properties._
val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
versionString, javaVmName, javaVersion)
diff --git a/sbin/slaves.sh b/sbin/slaves.sh
index f89547fef9e46..1d4dc5edf9858 100755
--- a/sbin/slaves.sh
+++ b/sbin/slaves.sh
@@ -36,29 +36,29 @@ if [ $# -le 0 ]; then
exit 1
fi
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
. "$sbin/spark-config.sh"
# If the slaves file is specified in the command line,
# then it takes precedence over the definition in
# spark-env.sh. Save it here.
-HOSTLIST=$SPARK_SLAVES
+HOSTLIST="$SPARK_SLAVES"
# Check if --config is passed as an argument. It is an optional parameter.
# Exit if the argument is not a directory.
if [ "$1" == "--config" ]
then
shift
- conf_dir=$1
+ conf_dir="$1"
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
- export SPARK_CONF_DIR=$conf_dir
+ export SPARK_CONF_DIR="$conf_dir"
fi
shift
fi
@@ -79,7 +79,7 @@ if [ "$SPARK_SSH_OPTS" = "" ]; then
fi
for slave in `cat "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do
- ssh $SPARK_SSH_OPTS $slave $"${@// /\\ }" \
+ ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \
2>&1 | sed "s/^/$slave: /" &
if [ "$SPARK_SLAVE_SLEEP" != "" ]; then
sleep $SPARK_SLAVE_SLEEP
diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh
index 5c87da5815b64..2718d6cba1c9a 100755
--- a/sbin/spark-config.sh
+++ b/sbin/spark-config.sh
@@ -21,19 +21,19 @@
# resolve links - $0 may be a softlink
this="${BASH_SOURCE-$0}"
-common_bin=$(cd -P -- "$(dirname -- "$this")" && pwd -P)
+common_bin="$(cd -P -- "$(dirname -- "$this")" && pwd -P)"
script="$(basename -- "$this")"
this="$common_bin/$script"
# convert relative path to absolute path
-config_bin=`dirname "$this"`
-script=`basename "$this"`
-config_bin=`cd "$config_bin"; pwd`
+config_bin="`dirname "$this"`"
+script="`basename "$this"`"
+config_bin="`cd "$config_bin"; pwd`"
this="$config_bin/$script"
-export SPARK_PREFIX=`dirname "$this"`/..
-export SPARK_HOME=${SPARK_PREFIX}
+export SPARK_PREFIX="`dirname "$this"`"/..
+export SPARK_HOME="${SPARK_PREFIX}"
export SPARK_CONF_DIR="$SPARK_HOME/conf"
# Add the PySpark classes to the PYTHONPATH:
-export PYTHONPATH=$SPARK_HOME/python:$PYTHONPATH
-export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH
+export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH"
+export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh
index 9032f23ea8eff..bd476b400e1c3 100755
--- a/sbin/spark-daemon.sh
+++ b/sbin/spark-daemon.sh
@@ -37,8 +37,8 @@ if [ $# -le 1 ]; then
exit 1
fi
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
. "$sbin/spark-config.sh"
@@ -50,14 +50,14 @@ sbin=`cd "$sbin"; pwd`
if [ "$1" == "--config" ]
then
shift
- conf_dir=$1
+ conf_dir="$1"
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
- export SPARK_CONF_DIR=$conf_dir
+ export SPARK_CONF_DIR="$conf_dir"
fi
shift
fi
@@ -100,12 +100,12 @@ if [ "$SPARK_LOG_DIR" = "" ]; then
export SPARK_LOG_DIR="$SPARK_HOME/logs"
fi
mkdir -p "$SPARK_LOG_DIR"
-touch $SPARK_LOG_DIR/.spark_test > /dev/null 2>&1
+touch "$SPARK_LOG_DIR"/.spark_test > /dev/null 2>&1
TEST_LOG_DIR=$?
if [ "${TEST_LOG_DIR}" = "0" ]; then
- rm -f $SPARK_LOG_DIR/.spark_test
+ rm -f "$SPARK_LOG_DIR"/.spark_test
else
- chown $SPARK_IDENT_STRING $SPARK_LOG_DIR
+ chown "$SPARK_IDENT_STRING" "$SPARK_LOG_DIR"
fi
if [ "$SPARK_PID_DIR" = "" ]; then
@@ -113,8 +113,8 @@ if [ "$SPARK_PID_DIR" = "" ]; then
fi
# some variables
-log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out
-pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid
+log="$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out"
+pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid"
# Set default scheduling priority
if [ "$SPARK_NICENESS" = "" ]; then
@@ -136,7 +136,7 @@ case $startStop in
fi
if [ "$SPARK_MASTER" != "" ]; then
- echo rsync from $SPARK_MASTER
+ echo rsync from "$SPARK_MASTER"
rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' $SPARK_MASTER/ "$SPARK_HOME"
fi
diff --git a/sbin/spark-executor b/sbin/spark-executor
index 3621321a9bc8d..674ce906d9421 100755
--- a/sbin/spark-executor
+++ b/sbin/spark-executor
@@ -17,10 +17,10 @@
# limitations under the License.
#
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
-export PYTHONPATH=$FWDIR/python:$PYTHONPATH
-export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH
+export PYTHONPATH="$FWDIR/python:$PYTHONPATH"
+export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
echo "Running spark-executor with framework dir = $FWDIR"
-exec $FWDIR/bin/spark-class org.apache.spark.executor.MesosExecutorBackend
+exec "$FWDIR"/bin/spark-class org.apache.spark.executor.MesosExecutorBackend
diff --git a/sbin/start-all.sh b/sbin/start-all.sh
index 5c89ab4d86b3a..1baf57cea09ee 100755
--- a/sbin/start-all.sh
+++ b/sbin/start-all.sh
@@ -21,8 +21,8 @@
# Starts the master on this node.
# Starts a worker on each node specified in conf/slaves
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
TACHYON_STR=""
diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh
index 580ab471b8a79..7172ad15d88fc 100755
--- a/sbin/start-history-server.sh
+++ b/sbin/start-history-server.sh
@@ -24,8 +24,8 @@
# Use the SPARK_HISTORY_OPTS environment variable to set history server configuration.
#
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
. "$sbin/spark-config.sh"
. "$SPARK_PREFIX/bin/load-spark-env.sh"
diff --git a/sbin/start-master.sh b/sbin/start-master.sh
index c5c02491f78e1..17fff58f4f768 100755
--- a/sbin/start-master.sh
+++ b/sbin/start-master.sh
@@ -19,8 +19,8 @@
# Starts the master on the machine this script is executed on.
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
START_TACHYON=false
diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh
index b563400dc24f3..2fc35309f4ca5 100755
--- a/sbin/start-slave.sh
+++ b/sbin/start-slave.sh
@@ -20,7 +20,7 @@
# Usage: start-slave.sh <worker#> <master-spark-URL>
#   where <master-spark-URL> is like "spark://localhost:7077"
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
"$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@"
diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh
index 4912d0c0c7dfd..ba1a84abc1fef 100755
--- a/sbin/start-slaves.sh
+++ b/sbin/start-slaves.sh
@@ -17,8 +17,8 @@
# limitations under the License.
#
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
START_TACHYON=false
@@ -46,11 +46,11 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then
fi
if [ "$SPARK_MASTER_IP" = "" ]; then
- SPARK_MASTER_IP=`hostname`
+ SPARK_MASTER_IP="`hostname`"
fi
if [ "$START_TACHYON" == "true" ]; then
- "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP
+ "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP"
# set -t so we can call sudo
SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1
@@ -58,12 +58,12 @@ fi
# Launch the slaves
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
- exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
+ exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT"
else
if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then
SPARK_WORKER_WEBUI_PORT=8081
fi
for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
- "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i ))
+ "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i ))
done
fi
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
index c519a77df4a14..4ce40fe750384 100755
--- a/sbin/start-thriftserver.sh
+++ b/sbin/start-thriftserver.sh
@@ -24,7 +24,7 @@
set -o posix
# Figure out where Spark is installed
-FWDIR="$(cd `dirname $0`/..; pwd)"
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
CLASS_NOT_FOUND_EXIT_STATUS=1
@@ -38,10 +38,10 @@ function usage {
pattern+="\|======="
pattern+="\|--help"
- $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
echo
echo "Thrift server options:"
- $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2
+ "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2
}
if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
@@ -49,7 +49,7 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
exit 0
fi
-source $FWDIR/bin/utils.sh
+source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh
index 60b358d374565..298c6a9859795 100755
--- a/sbin/stop-all.sh
+++ b/sbin/stop-all.sh
@@ -21,8 +21,8 @@
# Run this on the master node
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
# Load the Spark configuration
. "$sbin/spark-config.sh"
diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh
index c0034ad641cbe..6e6056359510f 100755
--- a/sbin/stop-history-server.sh
+++ b/sbin/stop-history-server.sh
@@ -19,7 +19,7 @@
# Stops the history server on the machine this script is executed on.
-sbin=`dirname "$0"`
-sbin=`cd "$sbin"; pwd`
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1
diff --git a/sbt/sbt b/sbt/sbt
index 1b1aa1483a829..c172fa74bc771 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -3,32 +3,32 @@
# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so
# that we can run Hive to generate the golden answer. This is not required for normal development
# or testing.
-for i in $HIVE_HOME/lib/*
-do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i
+for i in "$HIVE_HOME"/lib/*
+do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i"
done
export HADOOP_CLASSPATH
realpath () {
(
- TARGET_FILE=$1
+ TARGET_FILE="$1"
- cd $(dirname $TARGET_FILE)
- TARGET_FILE=$(basename $TARGET_FILE)
+ cd "$(dirname "$TARGET_FILE")"
+ TARGET_FILE="$(basename "$TARGET_FILE")"
COUNT=0
while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ]
do
- TARGET_FILE=$(readlink $TARGET_FILE)
- cd $(dirname $TARGET_FILE)
- TARGET_FILE=$(basename $TARGET_FILE)
+ TARGET_FILE="$(readlink "$TARGET_FILE")"
+ cd $(dirname "$TARGET_FILE")
+ TARGET_FILE="$(basename $TARGET_FILE)"
COUNT=$(($COUNT + 1))
done
- echo $(pwd -P)/$TARGET_FILE
+ echo "$(pwd -P)/"$TARGET_FILE""
)
}
-. $(dirname $(realpath $0))/sbt-launch-lib.bash
+. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash
declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash
index c91fecf024ad4..7f05d2ef491a3 100755
--- a/sbt/sbt-launch-lib.bash
+++ b/sbt/sbt-launch-lib.bash
@@ -7,7 +7,7 @@
# TODO - Should we merge the main SBT script with this library?
if test -z "$HOME"; then
- declare -r script_dir="$(dirname $script_path)"
+ declare -r script_dir="$(dirname "$script_path")"
else
declare -r script_dir="$HOME/.sbt"
fi
@@ -46,20 +46,20 @@ acquire_sbt_jar () {
if [[ ! -f "$sbt_jar" ]]; then
# Download sbt launch jar if it hasn't been downloaded yet
- if [ ! -f ${JAR} ]; then
+ if [ ! -f "${JAR}" ]; then
# Download
printf "Attempting to fetch sbt\n"
- JAR_DL=${JAR}.part
+ JAR_DL="${JAR}.part"
if hash curl 2>/dev/null; then
- (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR}
+ (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
elif hash wget 2>/dev/null; then
- (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR}
+ (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
fi
fi
- if [ ! -f ${JAR} ]; then
+ if [ ! -f "${JAR}" ]; then
# We failed to download
printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
exit -1
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 830711a46a35b..0d756f873e486 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent</artifactId>
-   <version>1.1.0-SNAPSHOT</version>
+   <version>1.2.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
old mode 100644
new mode 100755
index a88bd859fc85e..ca69531c69a77
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -73,6 +73,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val ASC = Keyword("ASC")
protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AVG = Keyword("AVG")
+ protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CACHE = Keyword("CACHE")
protected val CAST = Keyword("CAST")
@@ -81,6 +82,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val DISTINCT = Keyword("DISTINCT")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
+ protected val LAST = Keyword("LAST")
protected val FROM = Keyword("FROM")
protected val FULL = Keyword("FULL")
protected val GROUP = Keyword("GROUP")
@@ -124,6 +126,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SQRT = Keyword("SQRT")
+ protected val ABS = Keyword("ABS")
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
@@ -272,6 +275,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } |
termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
+ termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ {
+ case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
+ } |
termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } |
@@ -311,6 +317,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
} |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
+ LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
@@ -326,6 +333,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
} |
SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
+ ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } |
ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
}
@@ -349,16 +357,25 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
expression ~ "[" ~ expression <~ "]" ^^ {
case base ~ _ ~ ordinal => GetItem(base, ordinal)
} |
+ (expression <~ ".") ~ ident ^^ {
+ case base ~ fieldName => GetField(base, fieldName)
+ } |
TRUE ^^^ Literal(true, BooleanType) |
FALSE ^^^ Literal(false, BooleanType) |
cast |
"(" ~> expression <~ ")" |
function |
"-" ~> literal ^^ UnaryMinus |
+ dotExpressionHeader |
ident ^^ UnresolvedAttribute |
"*" ^^^ Star(None) |
literal
+ protected lazy val dotExpressionHeader: Parser[Expression] =
+ (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
+ case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", ""))
+ }
+
protected lazy val dataType: Parser[DataType] =
STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
}
@@ -372,7 +389,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":", "[", "]"
+ ",", ";", "%", "{", "}", ":", "[", "]", "."
)
override lazy val token: Parser[Token] = (
@@ -393,7 +410,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
| failure("illegal character")
)
- override def identChar = letter | elem('_') | elem('.')
+ override def identChar = letter | elem('_')
override def whitespace: Parser[Any] = rep(
whitespaceChar
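The parser changes above add BETWEEN, LAST and ABS keywords, desugar BETWEEN into a conjunction of comparisons, and move dotted field references out of the lexer (identChar no longer accepts '.') into the new dotExpressionHeader production. A minimal standalone sketch of the BETWEEN rewrite, using simplified stand-in expression classes rather than Catalyst's own:

object BetweenDesugarSketch {
  // Simplified stand-ins for Catalyst expressions; illustrative only.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class GreaterThanOrEqual(left: Expr, right: Expr) extends Expr
  case class LessThanOrEqual(left: Expr, right: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr

  // `e BETWEEN lower AND upper` becomes `e >= lower AND e <= upper`,
  // mirroring the new termExpression ~ BETWEEN ~ ... production above.
  def desugarBetween(e: Expr, lower: Expr, upper: Expr): Expr =
    And(GreaterThanOrEqual(e, lower), LessThanOrEqual(e, upper))

  def main(args: Array[String]): Unit = {
    // key BETWEEN 5 AND 7
    println(desugarBetween(Attr("key"), Lit(5), Lit(7)))
  }
}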
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index d6758eb5b6a32..bd8131c9af6e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -26,10 +26,22 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
val numericPrecedence =
- Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
- // Boolean is only wider than Void
- val booleanPrecedence = Seq(NullType, BooleanType)
- val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
+ val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+
+ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
+ val valueTypes = Seq(t1, t2).filter(t => t != NullType)
+ if (valueTypes.distinct.size > 1) {
+ // Try and find a promotion rule that contains both types in question.
+ val applicableConversion =
+ HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
+
+ // If found return the widest common type, otherwise None
+ applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ } else {
+ Some(if (valueTypes.size == 0) NullType else valueTypes.head)
+ }
+ }
}
/**
@@ -53,17 +65,6 @@ trait HiveTypeCoercion {
Division ::
Nil
- trait TypeWidening {
- def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion =
- HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
- // If found return the widest common type, otherwise None
- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
- }
- }
-
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -144,7 +145,8 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
- object WidenTypes extends Rule[LogicalPlan] with TypeWidening {
+ object WidenTypes extends Rule[LogicalPlan] {
+ import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
@@ -352,7 +354,9 @@ trait HiveTypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
+ object CaseWhenCoercion extends Rule[LogicalPlan] {
+ import HiveTypeCoercion._
+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
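With this refactor, findTightestCommonType lives on the HiveTypeCoercion object instead of a TypeWidening trait, and NullType now widens to any other type rather than sitting in the precedence lists. A standalone sketch of the same widening logic, with simplified stand-in type tokens (not Spark's DataType classes):

object WideningSketch {
  sealed trait DataType
  case object NullType extends DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType

  // Widening hierarchy from narrowest to widest, mirroring numericPrecedence above.
  val numericPrecedence: Seq[DataType] = Seq(IntegerType, LongType, DoubleType)
  val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil

  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
    val valueTypes = Seq(t1, t2).filter(_ != NullType)
    if (valueTypes.distinct.size > 1) {
      // Look for a promotion list containing both types and take the wider of the two.
      allPromotions.find(p => p.contains(t1) && p.contains(t2))
        .map(_.filter(t => t == t1 || t == t2).last)
    } else {
      // At most one non-null type: NullType widens to it (or stays NullType).
      Some(valueTypes.headOption.getOrElse(NullType))
    }
  }

  def main(args: Array[String]): Unit = {
    println(findTightestCommonType(IntegerType, DoubleType)) // Some(DoubleType)
    println(findTightestCommonType(NullType, StringType))    // Some(StringType)
    println(findTightestCommonType(IntegerType, StringType)) // None
  }
}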
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
old mode 100644
new mode 100755
index f44521d6381c9..deb622c39faf5
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -132,6 +132,7 @@ package object dsl {
def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd)
def avg(e: Expression) = Average(e)
def first(e: Expression) = First(e)
+ def last(e: Expression) = Last(e)
def min(e: Expression) = Min(e)
def max(e: Expression) = Max(e)
def upper(e: Expression) = Upper(e)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
index 75ea0e8459df8..088f11ee4aa53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -227,7 +227,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
new SpecificMutableRow(newValues)
}
- override def update(ordinal: Int, value: Any): Unit = values(ordinal).update(value)
+ override def update(ordinal: Int, value: Any): Unit = {
+ if (value == null) setNullAt(ordinal) else values(ordinal).update(value)
+ }
override def iterator: Iterator[Any] = values.map(_.boxed).iterator
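The one-line fix above routes a null assignment to setNullAt instead of the typed update. A simplified standalone sketch of the idea (the classes below are illustrative stand-ins, not Spark's MutableValue machinery):

object MutableRowNullSketch {
  final class MutableInt {
    var isNull: Boolean = true
    var value: Int = 0
    def update(v: Any): Unit = { value = v.asInstanceOf[Int]; isNull = false }
    def setNull(): Unit = { isNull = true; value = 0 }
  }

  final class TinyMutableRow(cells: Array[MutableInt]) {
    // Mirrors the fix: null marks the slot as null instead of being forwarded
    // to the typed update, which would silently store 0 and clear the null flag.
    def update(ordinal: Int, value: Any): Unit =
      if (value == null) cells(ordinal).setNull() else cells(ordinal).update(value)
    def isNullAt(ordinal: Int): Boolean = cells(ordinal).isNull
  }

  def main(args: Array[String]): Unit = {
    val row = new TinyMutableRow(Array(new MutableInt))
    row.update(0, null)
    println(row.isNullAt(0)) // true
  }
}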
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
old mode 100644
new mode 100755
index 15560a2a933ad..1b4d892625dbb
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -344,6 +344,21 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance() = new FirstFunction(child, this)
}
+case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
+ override def references = child.references
+ override def nullable = true
+ override def dataType = child.dataType
+ override def toString = s"LAST($child)"
+
+ override def asPartial: SplitEvaluation = {
+ val partialLast = Alias(Last(child), "PartialLast")()
+ SplitEvaluation(
+ Last(partialLast.toAttribute),
+ partialLast :: Nil)
+ }
+ override def newInstance() = new LastFunction(child, this)
+}
+
case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
@@ -489,3 +504,16 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: Row): Any = result
}
+
+case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+ def this() = this(null, null) // Required for serialization.
+
+ var result: Any = null
+
+ override def update(input: Row): Unit = {
+ result = input
+ }
+
+ override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row])
+ else null
+}
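The new Last aggregate keeps overwriting its buffer with each incoming row, and its asPartial split wraps the partial result in another Last so the final pass picks the last of the per-partition lasts. A simplified standalone sketch of that two-phase shape, assuming rows and partitions arrive in the order LAST should respect:

object LastAggregateSketch {
  // Per-partition pass: keep overwriting with the most recent value,
  // mirroring LastFunction.update above.
  def partialLast[A](partition: Iterator[A]): Option[A] =
    partition.foldLeft(Option.empty[A])((_, next) => Some(next))

  // Merge pass: the last of the partial results is the overall last,
  // which is why the SplitEvaluation wraps the partial attribute in another Last.
  def mergeLast[A](partials: Seq[Option[A]]): Option[A] =
    partials.flatten.lastOption

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Iterator(1, 2), Iterator(3, 4))
    println(mergeLast(partitions.map(partialLast))) // Some(4)
  }
}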
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index f988fb010b107..fe825fdcdae37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types._
+import scala.math.pow
case class UnaryMinus(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
@@ -129,3 +130,17 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
override def toString = s"MaxOf($left, $right)"
}
+
+/**
+ * A function that returns the absolute value of the numeric value.
+ */
+case class Abs(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def dataType = child.dataType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"Abs($child)"
+
+ override def eval(input: Row): Any = n1(child, input, _.abs(_))
+}
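Abs evaluates through Catalyst's numeric helper (n1), dispatching on the child's data type. A hedged standalone sketch of the evaluation semantics on boxed values (not the actual n1 plumbing):

object AbsSketch {
  def evalAbs(value: Any): Any = value match {
    case null      => null           // null propagates, matching nullable = child.nullable
    case i: Int    => math.abs(i)
    case l: Long   => math.abs(l)
    case f: Float  => math.abs(f)
    case d: Double => math.abs(d)
    case other     => sys.error(s"ABS is not defined for $other")
  }

  def main(args: Array[String]): Unit = {
    println(evalAbs(-1.3)) // 1.3
    println(evalAbs(null)) // null
  }
}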
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1313ccd120c1f..329af332d0fa1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -265,12 +265,13 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
false
} else {
val allCondBooleans = predicates.forall(_.dataType == BooleanType)
- val dataTypesEqual = values.map(_.dataType).distinct.size <= 1
+        // Both the THEN and the ELSE expressions' data types should be considered.
+ val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
allCondBooleans && dataTypesEqual
}
}
- /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */
+ /** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
val len = branchesArr.length
var i = 0
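The CaseWhen change makes resolution check the ELSE value's type against the THEN values, not just the THEN values against each other. A tiny standalone sketch of that check with type names as plain strings:

object CaseWhenResolveSketch {
  // condsAreBoolean stands in for the predicate check; the type check mirrors
  // (values ++ elseValue).map(_.dataType).distinct.size <= 1 above.
  def resolvedTypes(thenTypes: Seq[String], elseType: Option[String], condsAreBoolean: Boolean): Boolean =
    condsAreBoolean && (thenTypes ++ elseType).distinct.size <= 1

  def main(args: Array[String]): Unit = {
    println(resolvedTypes(Seq("int"), Some("int"), condsAreBoolean = true))    // true
    println(resolvedTypes(Seq("int"), Some("string"), condsAreBoolean = true)) // false, caught by the fix
  }
}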
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index f81d9111945f5..bae491f07c13f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -104,11 +104,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
- a.dataType match {
- case StructType(fields) =>
- Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
- case _ => None // Don't know how to resolve these field references
- }
+ Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case Seq() => None // No matches.
case ambiguousReferences =>
throw new TreeNodeException(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index b9e0f8e9dcc5f..ba8b853b6f99e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -23,20 +23,20 @@ import org.apache.spark.sql.catalyst.types._
class HiveTypeCoercionSuite extends FunSuite {
- val rules = new HiveTypeCoercion { }
- import rules._
-
- test("tightest common bound for numeric and boolean types") {
+ test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
- var found = WidenTypes.findTightestCommonType(t1, t2)
+ var found = HiveTypeCoercion.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
- found = WidenTypes.findTightestCommonType(t2, t1)
+ found = HiveTypeCoercion.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}
+ // Null
+ widenTest(NullType, NullType, Some(NullType))
+
// Boolean
widenTest(NullType, BooleanType, Some(BooleanType))
widenTest(BooleanType, BooleanType, Some(BooleanType))
@@ -60,12 +60,28 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(DoubleType, DoubleType, Some(DoubleType))
// Integral mixed with floating point.
- widenTest(NullType, FloatType, Some(FloatType))
- widenTest(NullType, DoubleType, Some(DoubleType))
widenTest(IntegerType, FloatType, Some(FloatType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
+
+ // StringType
+ widenTest(NullType, StringType, Some(StringType))
+ widenTest(StringType, StringType, Some(StringType))
+ widenTest(IntegerType, StringType, None)
+ widenTest(LongType, StringType, None)
+
+ // TimestampType
+ widenTest(NullType, TimestampType, Some(TimestampType))
+ widenTest(TimestampType, TimestampType, Some(TimestampType))
+ widenTest(IntegerType, TimestampType, None)
+ widenTest(StringType, TimestampType, None)
+
+ // ComplexType
+ widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
+ widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
+ widenTest(StringType, MapType(IntegerType, StringType, true), None)
+ widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c8016e41256d5..bd110218d34f7 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4137ac7663739..f6f4cf3b80d41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -53,7 +53,7 @@ private[spark] object SQLConf {
*
* SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
*/
-trait SQLConf {
+private[sql] trait SQLConf {
import SQLConf._
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5acb45c155ba5..a2f334aab9fdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
- catalog.registerTable(None, tableName, rdd.logicalPlan)
+ catalog.registerTable(None, tableName, rdd.queryExecution.analyzed)
}
/**
@@ -411,7 +411,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
- def simpleString: String =
+ def simpleString: String =
s"""== Physical Plan ==
|${stringOrError(executedPlan)}
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 33b2ed1b3a399..d2ceb4a2b0b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -428,7 +428,8 @@ class SchemaRDD(
*/
private def applySchema(rdd: RDD[Row]): SchemaRDD = {
new SchemaRDD(sqlContext,
- SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd))(sqlContext))
+ SparkLogicalPlan(
+ ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext))
}
// =======================================================================
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 0ea1105f082a4..595b4aa36eae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -30,7 +30,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
/**
* Functions for registering scala lambda functions as UDFs in a SQLContext.
*/
-protected[sql] trait UDFRegistration {
+private[sql] trait UDFRegistration {
self: SQLContext =>
private[spark] def registerPython(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index dc668e7dc934c..6eab2f23c18e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
-object InMemoryRelation {
+private[sql] object InMemoryRelation {
def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation =
new InMemoryRelation(child.output, useCompression, batchSize, child)()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 4802e40595807..927f40063e47e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -36,25 +36,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
override def outputPartitioning = newPartitioning
- def output = child.output
+ override def output = child.output
/** We must copy rows when sort based shuffle is on */
protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
- def execute() = attachTree(this , "execute") {
+ override def execute() = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
- val rdd = child.execute().mapPartitions { iter =>
- if (sortBasedShuffleOn) {
- @transient val hashExpressions =
- newProjection(expressions, child.output)
-
+ val rdd = if (sortBasedShuffleOn) {
+ child.execute().mapPartitions { iter =>
+ val hashExpressions = newProjection(expressions, child.output)
iter.map(r => (hashExpressions(r), r.copy()))
- } else {
- @transient val hashExpressions =
- newMutableProjection(expressions, child.output)()
-
+ }
+ } else {
+ child.execute().mapPartitions { iter =>
+ val hashExpressions = newMutableProjection(expressions, child.output)()
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
@@ -65,17 +63,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
- // TODO: RangePartitioner should take an Ordering.
- implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-
- val rdd = child.execute().mapPartitions { iter =>
- if (sortBasedShuffleOn) {
- iter.map(row => (row.copy(), null))
- } else {
+ val rdd = if (sortBasedShuffleOn) {
+ child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
+ } else {
+ child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Row, Null](null, null)
iter.map(row => mutablePair.update(row, null))
}
}
+
+ // TODO: RangePartitioner should take an Ordering.
+ implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
@@ -83,10 +82,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
shuffled.map(_._1)
case SinglePartition =>
- val rdd = child.execute().mapPartitions { iter =>
- if (sortBasedShuffleOn) {
- iter.map(r => (null, r.copy()))
- } else {
+ val rdd = if (sortBasedShuffleOn) {
+ child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
+ } else {
+ child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Null, Row]()
iter.map(r => mutablePair.update(null, r))
}
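The Exchange rewrite hoists the sort-based-shuffle check out of the per-row path and, when that shuffle manager is active, copies each row before handing it to the shuffle; with hash-based shuffle the reused MutablePair pattern remains safe. A standalone sketch of why the copy matters when an operator reuses one mutable object per partition (stand-in classes, not Spark's):

object ShuffleCopySketch {
  final class ReusedRow(var value: Int) { def copyRow(): ReusedRow = new ReusedRow(value) }

  // Sort-based shuffle may buffer references to incoming records before writing
  // them out, so a reused mutable row must be copied; if records are written as
  // they arrive, the same instance can be reused safely.
  def prepareForShuffle(values: Iterator[Int], sortBasedShuffleOn: Boolean): Iterator[ReusedRow] = {
    val reused = new ReusedRow(0)
    values.map { v =>
      reused.value = v
      if (sortBasedShuffleOn) reused.copyRow() else reused
    }
  }

  def main(args: Array[String]): Unit = {
    val copied = prepareForShuffle(Iterator(1, 2, 3), sortBasedShuffleOn = true).toList
    println(copied.map(_.value)) // List(1, 2, 3): independent copies
    val aliased = prepareForShuffle(Iterator(1, 2, 3), sortBasedShuffleOn = false).toList
    println(aliased.map(_.value)) // List(3, 3, 3): three references to one reused row
  }
}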
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4abda21ffec96..cac376608be29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.execution
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, SinglePartition, UnspecifiedDistribution}
import org.apache.spark.util.MutablePair
/**
@@ -96,7 +96,11 @@ case class Limit(limit: Int, child: SparkPlan)
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again
+ /** We must copy rows when sort based shuffle is on */
+ private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+
override def output = child.output
+ override def outputPartitioning = SinglePartition
/**
* A custom implementation modeled after the take function on RDDs but which never runs any job
@@ -143,9 +147,15 @@ case class Limit(limit: Int, child: SparkPlan)
}
override def execute() = {
- val rdd = child.execute().mapPartitions { iter =>
- val mutablePair = new MutablePair[Boolean, Row]()
- iter.take(limit).map(row => mutablePair.update(false, row))
+ val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) {
+ child.execute().mapPartitions { iter =>
+ iter.take(limit).map(row => (false, row.copy()))
+ }
+ } else {
+ child.execute().mapPartitions { iter =>
+ val mutablePair = new MutablePair[Boolean, Row]()
+ iter.take(limit).map(row => mutablePair.update(false, row))
+ }
}
val part = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
@@ -164,6 +174,7 @@ case class Limit(limit: Int, child: SparkPlan)
case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
override def output = child.output
+ override def outputPartitioning = SinglePartition
val ordering = new RowOrdering(sortOrder, child.output)
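Besides copying rows under sort-based shuffle (the same reasoning as in Exchange above), Limit and TakeOrdered now advertise SinglePartition as their output partitioning, so downstream planning reasons about where the limited rows actually live (the SPARK-3395 limit-distinct join test later in this diff exercises this). A simplified sketch of the satisfies relation with stand-in classes rather than Catalyst's:

object PartitioningSketch {
  sealed trait Distribution
  case object UnspecifiedDistribution extends Distribution
  case class ClusteredDistribution(keys: Seq[String]) extends Distribution

  sealed trait Partitioning { def satisfies(d: Distribution): Boolean }
  case object SinglePartition extends Partitioning {
    // All rows live in one partition, so any clustering requirement is met.
    def satisfies(d: Distribution): Boolean = true
  }
  case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
    def satisfies(d: Distribution): Boolean = d == UnspecifiedDistribution
  }

  def main(args: Array[String]): Unit = {
    println(SinglePartition.satisfies(ClusteredDistribution(Seq("a"))))          // true
    println(UnknownPartitioning(200).satisfies(ClusteredDistribution(Seq("a")))) // false
  }
}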
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 1c0b03c684f10..70062eae3b7ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -125,38 +125,31 @@ private[sql] object JsonRDD extends Logging {
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion = HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p
- .contains(t2))
-
- // If found return the widest common type, otherwise None
- val returnType = applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
-
- if (returnType.isDefined) {
- returnType.get
- } else {
- // t1 or t2 is a StructType, ArrayType, or an unexpected type.
- (t1, t2) match {
- case (other: DataType, NullType) => other
- case (NullType, other: DataType) => other
- case (StructType(fields1), StructType(fields2)) => {
- val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
- case (name, fieldTypes) => {
- val dataType = fieldTypes.map(field => field.dataType).reduce(
- (type1: DataType, type2: DataType) => compatibleType(type1, type2))
- StructField(name, dataType, true)
+ HiveTypeCoercion.findTightestCommonType(t1, t2) match {
+ case Some(commonType) => commonType
+ case None =>
+ // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+ (t1, t2) match {
+ case (other: DataType, NullType) => other
+ case (NullType, other: DataType) => other
+ case (StructType(fields1), StructType(fields2)) => {
+ val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
+ case (name, fieldTypes) => {
+ val dataType = fieldTypes.map(field => field.dataType).reduce(
+ (type1: DataType, type2: DataType) => compatibleType(type1, type2))
+ StructField(name, dataType, true)
+ }
}
+ StructType(newFields.toSeq.sortBy {
+ case StructField(name, _, _) => name
+ })
}
- StructType(newFields.toSeq.sortBy {
- case StructField(name, _, _) => name
- })
+ case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+ ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+ // TODO: We should use JsonObjectStringType to mark that values of field will be
+ // strings and every string is a Json object.
+ case (_, _) => StringType
}
- case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
- ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
- // TODO: We should use JsonObjectStringType to mark that values of field will be
- // strings and every string is a Json object.
- case (_, _) => StringType
- }
}
}
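The JSON schema-merging code now delegates the flat cases to HiveTypeCoercion.findTightestCommonType and keeps only the structural fallback (structs, arrays, and the catch-all StringType). A standalone sketch of that two-step merge with simplified schema tokens (not Spark's DataType hierarchy):

object JsonTypeMergeSketch {
  sealed trait T
  case object TNull extends T
  case object TInt extends T
  case object TDouble extends T
  case object TString extends T
  case class TStruct(fields: Map[String, T]) extends T

  val numeric: Seq[T] = Seq(TInt, TDouble)

  // Stand-in for the shared widening rule.
  def tightest(t1: T, t2: T): Option[T] =
    if (t1 == t2) Some(t1)
    else if (numeric.contains(t1) && numeric.contains(t2))
      Some(numeric.filter(t => t == t1 || t == t2).last)
    else None

  // Mirrors the refactored compatibleType: widening first, then structural
  // merging, and finally StringType as the catch-all.
  def compatible(t1: T, t2: T): T = tightest(t1, t2).getOrElse {
    (t1, t2) match {
      case (other, TNull) => other
      case (TNull, other) => other
      case (TStruct(f1), TStruct(f2)) =>
        TStruct((f1.keySet ++ f2.keySet).map { k =>
          k -> compatible(f1.getOrElse(k, TNull), f2.getOrElse(k, TNull))
        }.toMap)
      case _ => TString
    }
  }

  def main(args: Array[String]): Unit = {
    println(compatible(TStruct(Map("a" -> TInt)), TStruct(Map("a" -> TDouble, "b" -> TString))))
  }
}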
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 9fd6aed402838..2fc7e1cf23ab7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -382,7 +382,7 @@ private[parquet] class CatalystPrimitiveConverter(
parent.updateLong(fieldIndex, value)
}
-object CatalystArrayConverter {
+private[parquet] object CatalystArrayConverter {
val INITIAL_ARRAY_SIZE = 20
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index fe28e0d7269e0..7c83f1cad7d71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
-object ParquetFilters {
+private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
// set this to false if pushdown should be disabled
val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 1a6a6c17473a3..d001abb7e1fcc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.test._
/* Implicits */
@@ -133,6 +135,18 @@ class DslQuerySuite extends QueryTest {
mapData.take(1).toSeq)
}
+ test("SPARK-3395 limit distinct") {
+ val filtered = TestData.testData2
+ .distinct()
+ .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending))
+ .limit(1)
+ .registerTempTable("onerow")
+ checkAnswer(
+ sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
+ (1, 1, 1, 1) ::
+ (1, 1, 1, 2) :: Nil)
+ }
+
test("average") {
checkAnswer(
testData2.groupBy()(avg('a)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 651cb735ab7d9..811319e0a6601 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
class RowSuite extends FunSuite {
@@ -43,4 +43,10 @@ class RowSuite extends FunSuite {
assert(expected.getBoolean(2) === actual2.getBoolean(2))
assert(expected(3) === actual2(3))
}
+
+ test("SpecificMutableRow.update with null") {
+ val row = new SpecificMutableRow(Seq(IntegerType))
+ row(0) = null
+ assert(row.isNullAt(0))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1ac205937714c..514ac543df92a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -41,6 +41,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
+ test("SPARK-3176 Added Parser of SQL ABS()") {
+ checkAnswer(
+ sql("SELECT ABS(-1.3)"),
+ 1.3)
+ checkAnswer(
+ sql("SELECT ABS(0.0)"),
+ 0.0)
+ checkAnswer(
+ sql("SELECT ABS(2.5)"),
+ 2.5)
+ }
+
+ test("SPARK-3176 Added Parser of SQL LAST()") {
+ checkAnswer(
+ sql("SELECT LAST(n) FROM lowerCaseData"),
+ 4)
+ }
+
+
test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
@@ -53,14 +72,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
)
}
-
+
test("SQRT with automatic string casts") {
checkAnswer(
sql("SELECT SQRT(CAST(key AS STRING)) FROM testData"),
(1 to 100).map(x => Row(math.sqrt(x.toDouble))).toSeq
)
}
-
+
test("SPARK-2407 Added Parser of SQL SUBSTR()") {
checkAnswer(
sql("SELECT substr(tableName, 1, 2) FROM tableName"),
@@ -359,6 +378,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(null, null, 6, "F") :: Nil)
}
+ test("SPARK-3349 partitioning after limit") {
+ /*
+ sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
+ .limit(2)
+ .registerTempTable("subset1")
+ sql("SELECT DISTINCT n FROM lowerCaseData")
+ .limit(2)
+ .registerTempTable("subset2")
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"),
+ (3, "c", 3) ::
+ (4, "d", 4) :: Nil)
+ checkAnswer(
+ sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
+ (1, "a", 1) ::
+ (2, "b", 2) :: Nil)
+ */
+ }
+
test("mixed-case keywords") {
checkAnswer(
sql(
@@ -580,4 +618,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(3, null) ::
(4, 2147483644) :: Nil)
}
+
+ test("SPARK-3423 BETWEEN") {
+ checkAnswer(
+ sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"),
+ Seq((5, "5"), (6, "6"), (7, "7"))
+ )
+
+ checkAnswer(
+ sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"),
+ Seq((7, "7"))
+ )
+
+ checkAnswer(
+ sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"),
+ Seq()
+ )
+
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 05513a127150c..301d482d27d86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -581,4 +581,18 @@ class JsonSuite extends QueryTest {
"this is a simple string.") :: Nil
)
}
+
+ test("SPARK-2096 Correctly parse dot notations") {
+ val jsonSchemaRDD = jsonRDD(complexFieldAndType2)
+ jsonSchemaRDD.registerTempTable("jsonTable")
+
+ checkAnswer(
+ sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
+ (true, "str1") :: Nil
+ )
+ checkAnswer(
+ sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"),
+ ("str2", 6) :: Nil
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index a88310b5f1b46..b3f95f08e8044 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -82,4 +82,30 @@ object TestJsonData {
"""{"c":[33, 44]}""" ::
"""{"d":{"field":true}}""" ::
"""{"e":"str"}""" :: Nil)
+
+ val complexFieldAndType2 =
+ TestSQLContext.sparkContext.parallelize(
+ """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
+ "complexArrayOfStruct": [
+ {
+ "field1": [
+ {
+ "inner1": "str1"
+ },
+ {
+ "inner2": ["str2", "str22"]
+ }],
+ "field2": [[1, 2], [3, 4]]
+ },
+ {
+ "field1": [
+ {
+ "inner2": ["str3", "str33"]
+ },
+ {
+ "inner1": "str4"
+ }],
+ "field2": [[5, 6], [7, 8]]
+ }]
+ }""" :: Nil)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 42923b6a288d9..b0a06cd3ca090 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,19 +17,14 @@
package org.apache.spark.sql.parquet
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.mapreduce.Job
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
-
import parquet.hadoop.ParquetFileWriter
import parquet.hadoop.util.ContextUtil
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.mapreduce.Job
-
-import org.apache.spark.SparkContext
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
+import org.apache.spark.sql.catalyst.types.IntegerType
import org.apache.spark.sql.catalyst.util.getTempFilePath
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@@ -87,11 +82,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
var testRDD: SchemaRDD = null
- // TODO: remove this once SqlParser can parse nested select statements
- var nestedParserSqlContext: NestedParserSQLContext = null
-
override def beforeAll() {
- nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext)
ParquetTestData.writeFile()
ParquetTestData.writeFilterFile()
ParquetTestData.writeNestedFile1()
@@ -718,11 +709,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Projection in addressbook") {
- val data = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir1.toString)
- .toSchemaRDD
+ val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD
data.registerTempTable("data")
- val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data")
+ val query = sql("SELECT owner, contacts[1].name FROM data")
val tmp = query.collect()
assert(tmp.size === 2)
assert(tmp(0).size === 2)
@@ -733,21 +722,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Simple query on nested int data") {
- val data = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir2.toString)
- .toSchemaRDD
+ val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD
data.registerTempTable("data")
- val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect()
+ val result1 = sql("SELECT entries[0].value FROM data").collect()
assert(result1.size === 1)
assert(result1(0).size === 1)
assert(result1(0)(0) === 2.5)
- val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect()
+ val result2 = sql("SELECT entries[0] FROM data").collect()
assert(result2.size === 1)
val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]]
assert(subresult1.size === 2)
assert(subresult1(0) === 2.5)
assert(subresult1(1) === false)
- val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect()
+ val result3 = sql("SELECT outerouter FROM data").collect()
val subresult2 = result3(0)(0)
.asInstanceOf[CatalystConverter.ArrayScalaType[_]](0)
.asInstanceOf[CatalystConverter.ArrayScalaType[_]]
@@ -760,19 +747,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("nested structs") {
- val data = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir3.toString)
+ val data = parquetFile(ParquetTestData.testNestedDir3.toString)
.toSchemaRDD
data.registerTempTable("data")
- val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
+ val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect()
assert(result1.size === 1)
assert(result1(0).size === 1)
assert(result1(0)(0) === false)
- val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
+ val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect()
assert(result2.size === 1)
assert(result2(0).size === 1)
assert(result2(0)(0) === true)
- val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
+ val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect()
assert(result3.size === 1)
assert(result3(0).size === 1)
assert(result3(0)(0) === false)
@@ -796,11 +782,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("map with struct values") {
- val data = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir4.toString)
- .toSchemaRDD
+ val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD
data.registerTempTable("mapTable")
- val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect()
+ val result1 = sql("SELECT data2 FROM mapTable").collect()
assert(result1.size === 1)
val entry1 = result1(0)(0)
.asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
@@ -814,7 +798,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(entry2 != null)
assert(entry2(0) === 49)
assert(entry2(1) === null)
- val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect()
+ val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect()
assert(result2.size === 1)
assert(result2(0)(0) === 42.toLong)
assert(result2(0)(1) === "the answer")
@@ -825,15 +809,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
// has no effect in this test case
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
- val result = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir1.toString)
- .toSchemaRDD
+ val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD
result.saveAsParquetFile(tmpdir.toString)
- nestedParserSqlContext
- .parquetFile(tmpdir.toString)
+ parquetFile(tmpdir.toString)
.toSchemaRDD
.registerTempTable("tmpcopy")
- val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
+ val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
assert(tmpdata.size === 2)
assert(tmpdata(0).size === 2)
assert(tmpdata(0)(0) === "Julien Le Dem")
@@ -844,20 +825,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Writing out Map and reading it back in") {
- val data = nestedParserSqlContext
- .parquetFile(ParquetTestData.testNestedDir4.toString)
- .toSchemaRDD
+ val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
data.saveAsParquetFile(tmpdir.toString)
- nestedParserSqlContext
- .parquetFile(tmpdir.toString)
+ parquetFile(tmpdir.toString)
.toSchemaRDD
.registerTempTable("tmpmapcopy")
- val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect()
+ val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect()
assert(result1.size === 1)
assert(result1(0)(0) === 2)
- val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect()
+ val result2 = sql("SELECT data2 FROM tmpmapcopy").collect()
assert(result2.size === 1)
val entry1 = result2(0)(0)
.asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]]
@@ -871,42 +849,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(entry2 != null)
assert(entry2(0) === 49)
assert(entry2(1) === null)
- val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect()
+ val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect()
assert(result3.size === 1)
assert(result3(0)(0) === 42.toLong)
assert(result3(0)(1) === "the answer")
Utils.deleteRecursively(tmpdir)
}
}
-
-// TODO: the code below is needed temporarily until the standard parser is able to parse
-// nested field expressions correctly
-class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) {
- override protected[sql] val parser = new NestedSqlParser()
-}
-
-class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) {
- override def identChar = letter | elem('_')
- delimiters += (".")
-}
-
-class NestedSqlParser extends SqlParser {
- override val lexical = new NestedSqlLexical(reservedWords)
-
- override protected lazy val baseExpression: PackratParser[Expression] =
- expression ~ "[" ~ expression <~ "]" ^^ {
- case base ~ _ ~ ordinal => GetItem(base, ordinal)
- } |
- expression ~ "." ~ ident ^^ {
- case base ~ _ ~ fieldName => GetField(base, fieldName)
- } |
- TRUE ^^^ Literal(true, BooleanType) |
- FALSE ^^^ Literal(false, BooleanType) |
- cast |
- "(" ~> expression <~ ")" |
- function |
- "-" ~> literal ^^ UnaryMinus |
- ident ^^ UnresolvedAttribute |
- "*" ^^^ Star(None) |
- literal
-}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index c6f60c18804a4..124fc107cb8aa 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index f12b5a69a09f7..bd3f68d92d8c7 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -39,7 +39,9 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils
/**
* Executes queries using Spark SQL, and maintains a list of handles to active queries.
*/
-class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging {
+private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
+ extends OperationManager with Logging {
+
val handleToOperation = ReflectionUtils
.getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 35cccb9e5803d..6867766bc32c2 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent
- 1.1.0-SNAPSHOT
+ 1.2.0-SNAPSHOT
../../pom.xml
diff --git a/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c b/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then 1 else null end -0-f7c7fdd35c084bc797890aa08d33693c
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0 b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0
new file mode 100644
index 0000000000000..d3827e75a5cad
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then 1.0 else null end -0-aeb1f906bfe92f2d406f84109301afe0
@@ -0,0 +1 @@
+1.0
diff --git a/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb b/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then 1L else null end -0-763ae85e7a52b4cf4162d6a8931716bb
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982 b/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then 1S else null end -0-6f5f3b3dbe9f1d1eb98443aef315b982
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d b/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then 1Y else null end -0-589982a400d86157791c7216b10b6b5d
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171 b/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then null else 1 end -0-48bd83660cf3ba93cdbdc24559092171
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623 b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then null else 1.0 end -0-7f5ce763801781cf568c6a31dd80b623
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5 b/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then null else 1L end -0-a7f1305ea4f86e596c368e35e45cc4e5
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98 b/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then null else 1S end -0-dfb61969e6cb6e6dbe89225b538c8d98
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b b/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case when then null else 1Y end -0-7f4c32299c3738739b678ece62752a7b
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index c44eb3a63d51c..831b3c5a2cfd8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -17,11 +17,8 @@
package org.apache.spark.sql.hive.execution
-import java.io.File
-
import scala.util.Try
-import org.apache.spark.SparkException
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -507,14 +504,39 @@ class HiveQuerySuite extends HiveComparisonTest {
|WITH serdeproperties('s1'='9')
""".stripMargin)
}
- sql(s"ADD JAR $testJar")
- sql(
- """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe'
- |WITH serdeproperties('s1'='9')
- """.stripMargin)
+    /* For now only verify against Hive 0.12.0; other versions are skipped due to binary compatibility. */
+ if (HiveShim.version.equals("0.12.0")) {
+ sql(s"ADD JAR $testJar")
+ sql(
+ """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe'
+ |WITH serdeproperties('s1'='9')
+ """.stripMargin)
+ }
sql("DROP TABLE alter1")
}
+ case class LogEntry(filename: String, message: String)
+ case class LogFile(name: String)
+
+ test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") {
+ sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs")
+ sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles")
+
+ sql(
+ """
+ SELECT name, message
+ FROM rawLogs
+ JOIN (
+ SELECT name
+ FROM logFiles
+ ) files
+ ON rawLogs.filename = files.name
+ """).registerTempTable("boom")
+
+ // This should be successfully analyzed
+ sql("SELECT * FROM boom").queryExecution.analyzed
+ }
+
test("parse HQL set commands") {
// Adapted from its SQL counterpart.
val testKey = "spark.sql.key.usedfortestonly"
@@ -540,62 +562,67 @@ class HiveQuerySuite extends HiveComparisonTest {
val testKey = "spark.sql.key.usedfortestonly"
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
-
+ val KV = "([^=]+)=([^=]*)".r
+ def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ rdd.collect().map {
+ case Row(key: String, value: String) => key -> value
+ case Row(KV(key, value)) => key -> value
+ }.toSet
clear()
// "set" itself returns all config variables currently specified in SQLConf.
// TODO: Should we be listing the default here always? probably...
assert(sql("SET").collect().size == 0)
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql(s"SET $testKey=$testVal"))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql("SET"))
}
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
- sql(s"SET").collect().map(_.getString(0)).toSet
+ assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ collectResults(hql("SET"))
}
// "set key"
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(hql(s"SET $testKey"))
}
- assertResult(Array(s"$nonexistentKey=")) {
- sql(s"SET $nonexistentKey").collect().map(_.getString(0))
+ assertResult(Set(nonexistentKey -> "")) {
+ collectResults(hql(s"SET $nonexistentKey"))
}
// Assert that sql() should have the same effects as sql() by repeating the above using sql().
clear()
assert(sql("SET").collect().size == 0)
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey=$testVal").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql(s"SET $testKey=$testVal"))
}
assert(hiveconf.get(testKey, "") == testVal)
- assertResult(Array(s"$testKey=$testVal")) {
- sql("SET").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql("SET"))
}
sql(s"SET ${testKey + testKey}=${testVal + testVal}")
assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
- assertResult(Set(s"$testKey=$testVal", s"${testKey + testKey}=${testVal + testVal}")) {
- sql("SET").collect().map(_.getString(0)).toSet
+ assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+ collectResults(sql("SET"))
}
- assertResult(Array(s"$testKey=$testVal")) {
- sql(s"SET $testKey").collect().map(_.getString(0))
+ assertResult(Set(testKey -> testVal)) {
+ collectResults(sql(s"SET $testKey"))
}
- assertResult(Array(s"$nonexistentKey=")) {
- sql(s"SET $nonexistentKey").collect().map(_.getString(0))
+ assertResult(Set(nonexistentKey -> "")) {
+ collectResults(sql(s"SET $nonexistentKey"))
}
clear()
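The reworked assertions above normalize SET output into (key, value) pairs, accepting either a two-column Row or a single "key=value" string via the KV regex. A tiny standalone sketch of that regex extraction:

object SetOutputSketch {
  val KV = "([^=]+)=([^=]*)".r

  def main(args: Array[String]): Unit = {
    // Splits a flattened "key=value" line into a pair, as collectResults does above.
    "spark.sql.key.usedfortestonly=test.val.0" match {
      case KV(key, value) => println(key -> value)
    }
  }
}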
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index c3c18cf8ccac3..48fffe53cf2ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -33,6 +33,12 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
}
}
+ val nullVal = "null"
+ baseTypes.init.foreach { i =>
+ createQueryTest(s"case when then $i else $nullVal end ", s"SELECT case when true then $i else $nullVal end FROM src limit 1")
+ createQueryTest(s"case when then $nullVal else $i end ", s"SELECT case when true then $nullVal else $i end FROM src limit 1")
+ }
+
test("[SPARK-2210] boolean cast on boolean value should be removed") {
val q = "select cast(cast(key=0 as boolean) as boolean) from src"
val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 635a9fb0d56cb..b99caf77bce28 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.hive.execution
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
-import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.hive.test.TestHive._
+case class Nested1(f1: Nested2)
+case class Nested2(f2: Nested3)
+case class Nested3(f3: Int)
+
/**
* A collection of hive query tests where we generate the answers ourselves instead of depending on
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
@@ -47,4 +47,11 @@ class SQLQuerySuite extends QueryTest {
GROUP BY key, value
ORDER BY value) a""").collect().toSeq)
}
+
+ test("double nested data") {
+ sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
+ checkAnswer(
+ sql("SELECT f1.f2.f3 FROM nested"),
+ 1)
+ }
}
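
The new test leans on Spark SQL's case-class reflection, which infers a nested struct schema, so f1.f2.f3 resolves two levels of struct before reaching the Int field. Roughly (assumed output, not part of the patch):

    // Assumed schema inferred from Nested1(Nested2(Nested3(1))):
    // root
    //  |-- f1: struct (nullable = true)
    //  |    |-- f2: struct (nullable = true)
    //  |    |    |-- f3: integer (nullable = false)
    sql("SELECT f1.f2.f3 FROM nested").collect()   // Array([1])
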
diff --git a/streaming/pom.xml b/streaming/pom.xml
index ce35520a28609..12f900c91eb98 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent</artifactId>
-    <version>1.1.0-SNAPSHOT</version>
+    <version>1.2.0-SNAPSHOT</version>
   <relativePath>../pom.xml</relativePath>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 101cec1c7a7c2..457e8ab28ed82 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -240,7 +240,7 @@ class StreamingContext private[streaming] (
* Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html
* @param props Props object defining creation of the actor
* @param name Name of the actor
- * @param storageLevel RDD storage level. Defaults to memory-only.
+ * @param storageLevel RDD storage level (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*
* @note An important point to note:
* Since an Actor may exist outside the Spark framework, it is thus the user's responsibility
diff --git a/tools/pom.xml b/tools/pom.xml
index 97abb6b2b63e0..f36674476770c 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent</artifactId>
-    <version>1.1.0-SNAPSHOT</version>
+    <version>1.2.0-SNAPSHOT</version>
   <relativePath>../pom.xml</relativePath>
diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml
index 51744ece0412d..7dadbba58fd82 100644
--- a/yarn/alpha/pom.xml
+++ b/yarn/alpha/pom.xml
@@ -20,7 +20,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>yarn-parent_2.10</artifactId>
-    <version>1.1.0-SNAPSHOT</version>
+    <version>1.2.0-SNAPSHOT</version>
   <relativePath>../pom.xml</relativePath>
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
index ad27a9ab781d2..acf26505e4cf9 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.yarn
import scala.collection.{Map, Set}
+import java.net.URI
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
@@ -97,7 +98,8 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC
// Users can then monitor stderr/stdout on that node if required.
appMasterRequest.setHost(Utils.localHostName())
appMasterRequest.setRpcPort(0)
- appMasterRequest.setTrackingUrl(uiAddress)
+ // Remove the scheme from the URL if present, since Hadoop does not expect a scheme in the tracking URL.
+ appMasterRequest.setTrackingUrl(new URI(uiAddress).getAuthority())
resourceManager.registerApplicationMaster(appMasterRequest)
}
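
The tracking URL registered with the alpha-API ResourceManager now has its scheme stripped, since this Hadoop API expects a bare host:port rather than a full URL. A quick standalone check of the java.net.URI behavior relied on here (the host and port are illustrative):

    import java.net.URI

    // getAuthority() drops the scheme, leaving just host:port for the RM.
    assert(new URI("http://example.com:4040").getAuthority == "example.com:4040")
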
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index a879c833a014f..5756263e89e21 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -189,7 +189,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
if (sc == null) {
finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.")
} else {
- registerAM(sc.ui.appUIHostPort, securityMgr)
+ registerAM(sc.ui.appUIAddress, securityMgr)
try {
userThread.join()
} finally {
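
This pairs with the YarnRMClientImpl change above: the AM now hands registerAM the full UI address, scheme included, and the alpha RM client strips the scheme back off where the older API requires it. Illustrative values only (the actual host and port depend on the running UI):

    val appUIHostPort = "example.com:4040"          // what registerAM received before
    val appUIAddress  = "http://example.com:4040"   // what registerAM receives now
    // The alpha client recovers the old form via new URI(appUIAddress).getAuthority()
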
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 02b9a81bf6b50..0b8744f4b8bdf 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy.yarn
import java.util.{List => JList}
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
@@ -32,6 +32,8 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
object AllocationType extends Enumeration {
type AllocationType = Value
val HOST, RACK, ANY = Value
@@ -95,6 +97,14 @@ private[yarn] abstract class YarnAllocator(
protected val (preferredHostToCount, preferredRackToCount) =
generateNodeToWeight(conf, preferredNodes)
+ private val launcherPool = new ThreadPoolExecutor(
+ // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue
+ sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE,
+ 1, TimeUnit.MINUTES,
+ new LinkedBlockingQueue[Runnable](),
+ new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build())
+ launcherPool.allowCoreThreadTimeOut(true)
+
def getNumExecutorsRunning: Int = numExecutorsRunning.intValue
def getNumExecutorsFailed: Int = numExecutorsFailed.intValue
@@ -283,7 +293,7 @@ private[yarn] abstract class YarnAllocator(
executorMemory,
executorCores,
securityMgr)
- new Thread(executorRunnable).start()
+ launcherPool.execute(executorRunnable)
}
}
logDebug("""
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 3faaf053634d6..7fcd7ee0d4547 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -20,7 +20,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent</artifactId>
-    <version>1.1.0-SNAPSHOT</version>
+    <version>1.2.0-SNAPSHOT</version>
   <relativePath>../pom.xml</relativePath>
diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml
index b6c8456d06684..fd934b7726181 100644
--- a/yarn/stable/pom.xml
+++ b/yarn/stable/pom.xml
@@ -20,7 +20,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>yarn-parent_2.10</artifactId>
-    <version>1.1.0-SNAPSHOT</version>
+    <version>1.2.0-SNAPSHOT</version>
   <relativePath>../pom.xml</relativePath>