diff --git a/dev/run-tests b/dev/run-tests index c3d8f49cdd993..4be2baaf48cd1 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -24,6 +24,16 @@ cd "$FWDIR" # Remove work directory rm -rf ./work +source "$FWDIR/dev/run-tests-codes.sh" + +CURRENT_BLOCK=$BLOCK_GENERAL + +function handle_error () { + echo "[error] Got a return code of $? on line $1 of the run-tests script." + exit $CURRENT_BLOCK +} + + # Build against the right verison of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then fi fi -# Fail fast -set -e set -o pipefail +trap 'handle_error $LINENO' ERR echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_RAT + ./dev/check-license echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SCALA_STYLE + ./dev/lint-scala echo "" echo "=========================================================================" echo "Running Python style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYTHON_STYLE + ./dev/lint-python echo "" @@ -118,6 +136,8 @@ echo "=========================================================================" echo "Building Spark" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_BUILD + { # We always build with Hive because the PySpark Spark SQL tests need it. BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" @@ -141,6 +161,8 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS + { # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. @@ -175,10 +197,16 @@ echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS + ./python/run-tests echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + ./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh new file mode 100644 index 0000000000000..1348e0609dda4 --- /dev/null +++ b/dev/run-tests-codes.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +readonly BLOCK_GENERAL=10 +readonly BLOCK_RAT=11 +readonly BLOCK_SCALA_STYLE=12 +readonly BLOCK_PYTHON_STYLE=13 +readonly BLOCK_BUILD=14 +readonly BLOCK_SPARK_UNIT_TESTS=15 +readonly BLOCK_PYSPARK_UNIT_TESTS=16 +readonly BLOCK_MIMA=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 0b1e31b9413cf..451f3b771cc76 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,9 +26,23 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +source "$FWDIR/dev/run-tests-codes.sh" + COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" +# Important Environment Variables +# --- +# $ghprbActualCommit +#+ This is the hash of the most recent commit in the PR. +#+ The merge-base of this and master is the commit from which the PR was branched. +# $sha1 +#+ If the patch merges cleanly, this is a reference to the merge commit hash +#+ (e.g. "origin/pr/2606/merge"). +#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. +#+ The merge-base of this and master in the case of a clean merge is the most recent commit +#+ against master. + COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" @@ -84,42 +98,46 @@ function post_message () { fi } + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + # check PR merge-ability and check for new public classes { if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not** merge cleanly!" + merge_note=" * This patch **does not merge cleanly**." else merge_note=" * This patch merges cleanly." + fi + + source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " + ) + new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) - source_files=$( - git diff master... --name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master... ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ "$new_public_classes" == "" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi + if [ -z "$new_public_classes" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" fi } @@ -147,12 +165,30 @@ function post_message () { post_message "$fail_message" exit $test_result + elif [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes all tests**." else - if [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes** unit tests." + if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then + failing_test="some tests" + elif [ "$test_result" -eq "$BLOCK_RAT" ]; then + failing_test="RAT tests" + elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then + failing_test="Scala style tests" + elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then + failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then + failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then + failing_test="Spark unit tests" + elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then + failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" else - test_result_note=" * This patch **fails** unit tests." + failing_test="some tests" fi + + test_result_note=" * This patch **fails $failing_test**." fi } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index d0fe4179685ca..00dfc86c9e0bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } + + override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) } /** diff --git a/python/docs/modules.rst b/python/docs/modules.rst deleted file mode 100644 index 183564659fbcf..0000000000000 --- a/python/docs/modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -. -= - -.. toctree:: - :maxdepth: 4 - - pyspark diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e9418320ff781..a45d79d6424c7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -410,6 +410,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. The mechanism is as follows: + 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key and value Writable classes 2. Serialization is attempted via Pyrolite pickling diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ac142fb49a90c..a765b1c4f7d87 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -89,11 +89,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param regParam: The regularizer parameter (default: 1.0). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater + - "none" for no regularizer + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features @@ -158,11 +161,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, @param initialWeights: The initial weights (default: None). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cbdbc09858013..54f34a98337ca 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -22,7 +22,7 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -66,6 +66,9 @@ def weights(self): def intercept(self): return self._intercept + def __repr__(self): + return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + class LinearRegressionModelBase(LinearModel): @@ -152,11 +155,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param regParam: The regularizer parameter (default: 1.0). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater, + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f72e88ba6e2ba..5c20e100e144f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -32,7 +32,7 @@ from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint -from pyspark.tests import PySparkTestCase +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index afdcdbdf3ae01..5d7abfb96b7fe 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -48,6 +48,7 @@ def __del__(self): def predict(self, x): """ Predict the label of one or more examples. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..e77669aad76b6 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1208,6 +1208,7 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file system, using the L{org.apache.hadoop.io.Writable} types that we convert from the RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2672da36c1f50..099fa54cf2bd7 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -211,7 +211,7 @@ def __eq__(self, other): return (isinstance(other, BatchedSerializer) and other.serializer == self.serializer) - def __str__(self): + def __repr__(self): return "BatchedSerializer<%s>" % str(self.serializer) @@ -279,7 +279,7 @@ def __eq__(self, other): return (isinstance(other, CartesianDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "CartesianDeserializer<%s, %s>" % \ (str(self.key_ser), str(self.val_ser)) @@ -306,7 +306,7 @@ def __eq__(self, other): return (isinstance(other, PairDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index ce597cbe91e15..d57a802e4734a 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -396,7 +396,6 @@ def _external_items(self): for v in self.data.iteritems(): yield v self.data.clear() - gc.collect() # remove the merged partition for j in range(self.spills): @@ -428,7 +427,7 @@ def _recursive_merged_items(self, start): subdirs = [os.path.join(d, "parts", str(i)) for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions, self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -486,7 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch = 10 + batch = 100 chunks, current_chunk = [], [] iterator = iter(iterator) while True: diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 974b5e287bc00..114644ab8b79d 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -201,7 +201,7 @@ def __init__(self, elementType, containsNull=True): self.elementType = elementType self.containsNull = containsNull - def __str__(self): + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6fb6bc998c752..7f05d48ade2b3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -67,10 +67,10 @@ SPARK_HOME = os.environ["SPARK_HOME"] -class TestMerger(unittest.TestCase): +class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 16 + self.N = 1 << 14 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -115,7 +115,7 @@ def test_medium_dataset(self): sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), @@ -123,7 +123,7 @@ def test_huge_dataset(self): m._cleanup() -class TestSorter(unittest.TestCase): +class SorterTests(unittest.TestCase): def test_in_memory_sort(self): l = range(1024) random.shuffle(l) @@ -244,16 +244,25 @@ def tearDown(self): sys.path = self._old_sys_path -class TestCheckpoint(PySparkTestCase): +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class CheckpointTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.checkpointDir.name) self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -288,7 +297,7 @@ def test_checkpoint_and_restore(self): self.assertEquals([1, 2, 3, 4], recovered.collect()) -class TestAddFile(PySparkTestCase): +class AddFileTests(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -354,7 +363,7 @@ def func(x): self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) -class TestRDDFunctions(PySparkTestCase): +class RDDTests(ReusedPySparkTestCase): def test_id(self): rdd = self.sc.parallelize(range(10)) @@ -365,12 +374,6 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.sc.stop() - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - self.sc = SparkContext("local") - def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -636,7 +639,7 @@ def test_distinct(self): self.assertEquals(result.count(), 3) -class TestProfiler(PySparkTestCase): +class ProfilerTests(PySparkTestCase): def setUp(self): self._old_sys_path = list(sys.path) @@ -666,10 +669,9 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) -class TestSQL(PySparkTestCase): +class SQLTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.sqlCtx = SQLContext(self.sc) def test_udf(self): @@ -754,27 +756,19 @@ def test_serialize_nested_array_and_map(self): self.assertEqual("2", row.d) -class TestIO(PySparkTestCase): - - def test_stdout_redirection(self): - import subprocess - - def func(x): - subprocess.check_call('ls', shell=True) - self.sc.parallelize([1]).foreach(func) +class InputFormatTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) -class TestInputFormat(PySparkTestCase): - - def setUp(self): - PySparkTestCase.setUp(self) - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc) - - def tearDown(self): - PySparkTestCase.tearDown(self) - shutil.rmtree(self.tempdir.name) + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) def test_sequencefiles(self): basepath = self.tempdir.name @@ -954,15 +948,13 @@ def test_converters(self): self.assertEqual(maps, em) -class TestOutputFormat(PySparkTestCase): +class OutputFormatTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.tempdir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.tempdir.name, ignore_errors=True) def test_sequencefiles(self): @@ -1243,8 +1235,7 @@ def test_malformed_RDD(self): basepath + "/malformed/sequence")) -class TestDaemon(unittest.TestCase): - +class DaemonTests(unittest.TestCase): def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -1290,7 +1281,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class TestWorker(PySparkTestCase): +class WorkerTests(PySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) @@ -1342,11 +1333,6 @@ def run(): rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) - def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default - rdd = self.sc.parallelize(range(N), N) - self.assertEquals(N, rdd.count()) - def test_after_exception(self): def raise_exception(_): raise Exception() @@ -1379,7 +1365,7 @@ def test_accumulator_when_reuse_worker(self): self.assertEqual(sum(range(100)), acc1.value) -class TestSparkSubmit(unittest.TestCase): +class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() @@ -1492,6 +1478,8 @@ def test_single_script_on_cluster(self): |sc = SparkContext() |print sc.parallelize([1, 2, 3]).map(foo).collect() """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf proc = subprocess.Popen( [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) @@ -1500,7 +1488,11 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) -class ContextStopTests(unittest.TestCase): +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) def test_stop(self): sc = SparkContext() diff --git a/python/run-tests b/python/run-tests index a7ec270c7da21..c713861eb77bb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -34,7 +34,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -48,6 +48,37 @@ function run_test() { fi } +function run_core_tests() { + echo "Run core tests ..." + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py" + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +} + +function run_sql_tests() { + echo "Run sql tests ..." + run_test "pyspark/sql.py" +} + +function run_mllib_tests() { + echo "Run mllib tests ..." + run_test "pyspark/mllib/classification.py" + run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/linalg.py" + run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/recommendation.py" + run_test "pyspark/mllib/regression.py" + run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/tree.py" + run_test "pyspark/mllib/util.py" + run_test "pyspark/mllib/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -60,29 +91,9 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_test "pyspark/rdd.py" -run_test "pyspark/context.py" -run_test "pyspark/conf.py" -run_test "pyspark/sql.py" -# These tests are included in the module-level docs, and so must -# be handled on a higher level rather than within the python file. -export PYSPARK_DOC_TEST=1 -run_test "pyspark/broadcast.py" -run_test "pyspark/accumulators.py" -run_test "pyspark/serializers.py" -unset PYSPARK_DOC_TEST -run_test "pyspark/shuffle.py" -run_test "pyspark/tests.py" -run_test "pyspark/mllib/classification.py" -run_test "pyspark/mllib/clustering.py" -run_test "pyspark/mllib/linalg.py" -run_test "pyspark/mllib/random.py" -run_test "pyspark/mllib/recommendation.py" -run_test "pyspark/mllib/regression.py" -run_test "pyspark/mllib/stat.py" -run_test "pyspark/mllib/tests.py" -run_test "pyspark/mllib/tree.py" -run_test "pyspark/mllib/util.py" +run_core_tests +run_sql_tests +run_mllib_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -90,19 +101,8 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - run_test "pyspark/sql.py" - # These tests are included in the module-level docs, and so must - # be handled on a higher level rather than within the python file. - export PYSPARK_DOC_TEST=1 - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - unset PYSPARK_DOC_TEST - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_core_tests + run_sql_tests fi if [[ $FAILED == 0 ]]; then diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 9bd1719cb1808..7faf55bc63372 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -40,6 +40,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var rpc: YarnRPC = null private var resourceManager: AMRMProtocol = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -51,8 +52,11 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.rpc = YarnRPC.create(conf) this.uiHistoryAddress = uiHistoryAddress - resourceManager = registerWithResourceManager(conf) - registerApplicationMaster(uiAddress) + synchronized { + resourceManager = registerWithResourceManager(conf) + registerApplicationMaster(uiAddress) + registered = true + } new YarnAllocationHandler(conf, sparkConf, resourceManager, getAttemptId(), args, preferredNodeLocations, securityMgr) @@ -66,14 +70,16 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appAttemptId } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = { - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(getAttemptId()) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(uiHistoryAddress) - resourceManager.finishApplicationMaster(finishReq) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(getAttemptId()) + finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) + finishReq.setTrackingUrl(uiHistoryAddress) + resourceManager.finishApplicationMaster(finishReq) + } } override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index caceef5d4b5b0..a3c43b43848d2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -56,8 +57,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + @volatile private var exitCode = 0 + @volatile private var unregistered = false @volatile private var finished = false @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + @volatile private var finalMsg: String = "" @volatile private var userClassThread: Thread = _ private var reporterThread: Thread = _ @@ -71,80 +75,107 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val sparkContextRef = new AtomicReference[SparkContext](null) final def run(): Int = { - val appAttemptId = client.getAttemptId() + try { + val appAttemptId = client.getAttemptId() - if (isDriver) { - // Set the web ui port to be ephemeral for yarn so we don't conflict with - // other spark processes running on the same box - System.setProperty("spark.ui.port", "0") + if (isDriver) { + // Set the web ui port to be ephemeral for yarn so we don't conflict with + // other spark processes running on the same box + System.setProperty("spark.ui.port", "0") - // Set the master property to match the requested mode. - System.setProperty("spark.master", "yarn-cluster") + // Set the master property to match the requested mode. + System.setProperty("spark.master", "yarn-cluster") - // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. - System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - } + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + } - logInfo("ApplicationAttemptId: " + appAttemptId) + logInfo("ApplicationAttemptId: " + appAttemptId) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - finish(FinalApplicationStatus.SUCCEEDED) - } + val cleanupHook = new Runnable { + override def run() { + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // this shouldn't ever happen, but if it does assume weird failure + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "shutdown hook called without cleanly finishing") + } - // Cleanup the staging dir after the app is finished, or if it's the last attempt at - // running the AM. - val maxAppAttempts = client.getMaxRegAttempts(yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - if (finished || isLastAttempt) { - cleanupStagingDir() + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir() + } + } } } - } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) + // Use higher priority than FileSystem. + assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) + ShutdownHookManager + .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserClass which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) + // Call this to force generation of secret so it gets populated into the + // Hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. + val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { - runDriver(securityMgr) - } else { - runExecutorLauncher(securityMgr) + if (isDriver) { + runDriver(securityMgr) + } else { + runExecutorLauncher(securityMgr) + } + } catch { + case e: Exception => + // catch everything else if not specifically handled + logError("Uncaught exception: ", e) + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "Uncaught exception: " + e.getMessage()) } + exitCode + } - if (finalStatus != FinalApplicationStatus.UNDEFINED) { - finish(finalStatus) - 0 - } else { - 1 + /** + * unregister is used to completely unregister the application from the ResourceManager. + * This means the ResourceManager will not retry the application attempt on your behalf if + * a failure occurred. + */ + final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) } } - final def finish(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { if (!finished) { - logInfo(s"Finishing ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - finished = true + logInfo(s"Final app status: ${status}, exitCode: ${code}" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code finalStatus = status - try { - if (Thread.currentThread() != reporterThread) { - reporterThread.interrupt() - reporterThread.join() - } - } finally { - client.shutdown(status, Option(diagnostics).getOrElse("")) + finalMsg = msg + finished = true + if (Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() } } } @@ -182,7 +213,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() - val userThread = startUserClass() + setupSystemSecurityManager() + userClassThread = startUserClass() // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. @@ -190,15 +222,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // If there is no SparkContext at this point, just fail the app. if (sc == null) { - finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SC_NOT_INITED, + "Timed out waiting for SparkContext.") } else { registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) - try { - userThread.join() - } finally { - // In cluster mode, ask the reporter thread to stop since the user app is finished. - reporterThread.interrupt() - } + userClassThread.join() } } @@ -211,7 +240,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // In client mode the actor will stop the reporter thread. reporterThread.join() - finalStatus = FinalApplicationStatus.SUCCEEDED } private def launchReporterThread(): Thread = { @@ -231,33 +259,26 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val t = new Thread { override def run() { var failureCount = 0 - while (!finished) { try { - checkNumExecutorsFailed() - if (!finished) { + if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Max number of executor failures reached") + } else { logDebug("Sending progress") allocator.allocateResources() } failureCount = 0 } catch { + case i: InterruptedException => case e: Throwable => { failureCount += 1 if (!NonFatal(e) || failureCount >= reporterMaxFailures) { - logError("Exception was thrown from Reporter thread.", e) - finish(FinalApplicationStatus.FAILED, "Exception was thrown" + - s"${failureCount} time(s) from Reporter thread.") - - /** - * If exception is thrown from ReporterThread, - * interrupt user class to stop. - * Without this interrupting, if exception is - * thrown before allocating enough executors, - * YarnClusterScheduler waits until timeout even though - * we cannot allocate executors. - */ - logInfo("Interrupting user class to stop.") - userClassThread.interrupt + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + + s"${failureCount} time(s) from Reporter thread.") + } else { logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) } @@ -308,7 +329,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkContextRef.synchronized { var count = 0 val waitTime = 10000L - val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) + val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10) while (sparkContextRef.get() == null && count < numTries && !finished) { logInfo("Waiting for spark context initialization ... " + count) count = count + 1 @@ -328,10 +349,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def waitForSparkDriver(): ActorRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false + var count = 0 val hostport = args.userArgs(0) val (driverHost, driverPort) = Utils.parseHostPort(hostport) - while (!driverUp) { + + // spark driver should already be up since it launched us, but we don't want to + // wait forever, so wait 100 seconds max to match the cluster mode setting. + // Leave this config unpublished for now. SPARK-3779 to investigating changing + // this config to be time based. + val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 1000) + + while (!driverUp && !finished && count < numTries) { try { + count = count + 1 val socket = new Socket(driverHost, driverPort) socket.close() logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) @@ -343,6 +373,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, Thread.sleep(100) } } + + if (!driverUp) { + throw new SparkException("Failed to connect to driver!") + } + sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) @@ -354,18 +389,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") } - private def checkNumExecutorsFailed() = { - if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finish(FinalApplicationStatus.FAILED, "Max number of executor failures reached.") - - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from checkNumExecutorsFailed") - sc.stop() - } - } - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter() = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) @@ -379,40 +402,81 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } + /** + * This system security manager applies to the entire process. + * It's main purpose is to handle the case if the user code does a System.exit. + * This allows us to catch that and properly set the YARN application status and + * cleanup if needed. + */ + private def setupSystemSecurityManager(): Unit = { + try { + var stopped = false + System.setSecurityManager(new java.lang.SecurityManager() { + override def checkExit(paramInt: Int) { + if (!stopped) { + logInfo("In securityManager checkExit, exit code: " + paramInt) + if (paramInt == 0) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } else { + finish(FinalApplicationStatus.FAILED, + paramInt, + "User class exited with non-zero exit code") + } + stopped = true + } + } + // required for the checkExit to work properly + override def checkPermission(perm: java.security.Permission): Unit = {} + }) + } + catch { + case e: SecurityException => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SECURITY, + "Error in setSecurityManager") + logError("Error in setSecurityManager:", e) + } + } + + /** + * Start the user class, which contains the spark driver, in a separate Thread. + * If the main routine exits cleanly or exits with System.exit(0) we + * assume it was successful, for all other cases we assume failure. + * + * Returns the user thread that was started. + */ private def startUserClass(): Thread = { logInfo("Starting the user JAR in a separate Thread") System.setProperty("spark.executor.instances", args.numExecutors.toString) val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - userClassThread = new Thread { + val userThread = new Thread { override def run() { - var status = FinalApplicationStatus.FAILED try { - // Copy val mainArgs = new Array[String](args.userArgs.size) args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) - // Some apps have "System.exit(0)" at the end. The user thread will stop here unless - // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. - status = FinalApplicationStatus.SUCCEEDED + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running users class") } catch { case e: InvocationTargetException => e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - - case e => throw e + case e: Exception => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, + "User class threw exception: " + e.getMessage) + // re-throw to get it logged + throw e } - } finally { - logDebug("Finishing main") - finalStatus = status } } } - userClassThread.setName("Driver") - userClassThread.start() - userClassThread + userThread.setName("Driver") + userThread.start() + userThread } // Actor used to monitor the driver when running in client deploy mode. @@ -432,7 +496,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, override def receive = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver ! x @@ -446,6 +510,15 @@ object ApplicationMaster extends Logging { val SHUTDOWN_HOOK_PRIORITY: Int = 30 + // exit codes for different causes, no reason behind the values + private val EXIT_SUCCESS = 0 + private val EXIT_UNCAUGHT_EXCEPTION = 10 + private val EXIT_MAX_EXECUTOR_FAILURES = 11 + private val EXIT_REPORTER_FAILURE = 12 + private val EXIT_SC_NOT_INITED = 13 + private val EXIT_SECURITY = 14 + private val EXIT_EXCEPTION_USER_CLASS = 15 + private var master: ApplicationMaster = _ def main(args: Array[String]) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 943dc56202a37..2510b9c9cef68 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -49,12 +49,12 @@ trait YarnRMClient { securityMgr: SecurityManager): YarnAllocator /** - * Shuts down the AM. Guaranteed to only be called once. + * Unregister the AM. Guaranteed to only be called once. * * @param status The final status of the AM. * @param diagnostics Diagnostics message to include in the final status. */ - def shutdown(status: FinalApplicationStatus, diagnostics: String = ""): Unit + def unregister(status: FinalApplicationStatus, diagnostics: String = ""): Unit /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index b581790e158ac..8d4b96ed79933 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -45,6 +45,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -59,13 +60,19 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.uiHistoryAddress = uiHistoryAddress logInfo("Registering the ApplicationMaster") - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + synchronized { + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + registered = true + } new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args, preferredNodeLocations, securityMgr) } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = - amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + } + } override def getAttemptId() = { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())