From 61e21fe7f478e7b06b72851f26b87d99cbbdf117 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 16 Sep 2014 09:18:03 -0700 Subject: [PATCH 001/315] SPARK-3069 [DOCS] Build instructions in README are outdated Here's my crack at Bertrand's suggestion. The Github `README.md` contains build info that's outdated. It should just point to the current online docs, and reflect that Maven is the primary build now. (Incidentally, the stanza at the end about contributions of original work should go in https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark too. It won't hurt to be crystal clear about the agreement to license, given that ICLAs are not required of anyone here.) Author: Sean Owen Closes #2014 from srowen/SPARK-3069 and squashes the following commits: 501507e [Sean Owen] Note that Zinc is for Maven builds too db2bd97 [Sean Owen] sbt -> sbt/sbt and add note about zinc be82027 [Sean Owen] Fix additional occurrences of building-with-maven -> building-spark 91c921f [Sean Owen] Move building-with-maven to building-spark and create a redirect. Update doc links to building-spark.html Add jekyll-redirect-from plugin and make associated config changes (including fixing pygments deprecation). Add example of SBT to README.md 999544e [Sean Owen] Change "Building Spark with Maven" title to "Building Spark"; reinstate tl;dr info about dev/run-tests in README.md; add brief note about building with SBT c18d140 [Sean Owen] Optionally, remove the copy of contributing text from main README.md 8e83934 [Sean Owen] Add CONTRIBUTING.md to trigger notice on new pull request page b1c04a1 [Sean Owen] Refer to current online documentation for building, and remove slightly outdated copy in README.md --- CONTRIBUTING.md | 12 +++ README.md | 78 ++++--------------- docs/README.md | 5 +- docs/_config.yml | 4 +- docs/_layouts/global.html | 2 +- ...ilding-with-maven.md => building-spark.md} | 20 ++++- docs/hadoop-third-party-distributions.md | 2 +- docs/index.md | 4 +- docs/running-on-yarn.md | 2 +- docs/streaming-kinesis-integration.md | 2 +- make-distribution.sh | 2 +- 11 files changed, 60 insertions(+), 73 deletions(-) create mode 100644 CONTRIBUTING.md rename docs/{building-with-maven.md => building-spark.md} (87%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000..c6b4aa5344757 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,12 @@ +## Contributing to Spark + +Contributions via GitHub pull requests are gladly accepted from their original +author. Along with any pull requests, please state that the contribution is +your original work and that you license the work to the project under the +project's open source license. Whether or not you state this explicitly, by +submitting any copyrighted material via pull request, email, or other means +you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. + +Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) +for more information. diff --git a/README.md b/README.md index 5b09ad86849e7..b05bbfb5a594c 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,19 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the project webpage at . +guide, on the [project web page](http://spark.apache.org/documentation.html). This README file only contains basic setup instructions. ## Building Spark -Spark is built on Scala 2.10. To build Spark and its example programs, run: +Spark is built using [Apache Maven](http://maven.apache.org/). +To build Spark and its example programs, run: - ./sbt/sbt assembly + mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) +More detailed documentation is available from the project site, at +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). ## Interactive Scala Shell @@ -71,73 +74,24 @@ can be run using: ./dev/run-tests +Please see the guidance on how to +[run all automated tests](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-AutomatedTesting) + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. -You can change the version by setting `-Dhadoop.version` when building Spark. - -For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop -versions without YARN, use: - - # Apache Hadoop 1.2.1 - $ sbt/sbt -Dhadoop.version=1.2.1 assembly - - # Cloudera CDH 4.2.0 with MapReduce v1 - $ sbt/sbt -Dhadoop.version=2.0.0-mr1-cdh4.2.0 assembly - -For Apache Hadoop 2.2.X, 2.1.X, 2.0.X, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions -with YARN, also set `-Pyarn`: - - # Apache Hadoop 2.0.5-alpha - $ sbt/sbt -Dhadoop.version=2.0.5-alpha -Pyarn assembly - - # Cloudera CDH 4.2.0 with MapReduce v2 - $ sbt/sbt -Dhadoop.version=2.0.0-cdh4.2.0 -Pyarn assembly - - # Apache Hadoop 2.2.X and newer - $ sbt/sbt -Dhadoop.version=2.2.0 -Pyarn assembly - -When developing a Spark application, specify the Hadoop version by adding the -"hadoop-client" artifact to your project's dependencies. For example, if you're -using Hadoop 1.2.1 and build your application using SBT, add this entry to -`libraryDependencies`: - - "org.apache.hadoop" % "hadoop-client" % "1.2.1" -If your project is built with Maven, add this to your POM file's `` section: - - - org.apache.hadoop - hadoop-client - 1.2.1 - - - -## A Note About Thrift JDBC server and CLI for Spark SQL - -Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about using the JDBC server and CLI. -You can use those features by setting `-Phive` when building Spark as follows. - - $ sbt/sbt -Phive assembly +Please refer to the build documentation at +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +for detailed guidance on building for a particular distribution of Hadoop, including +building for particular Hive and Hive Thriftserver distributions. See also +["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) +for guidance on building a Spark application that works with a particular +distribution. ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. - - -## Contributing to Spark - -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. - -Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. diff --git a/docs/README.md b/docs/README.md index 0a0126c5747d1..fdc89d2eb767a 100644 --- a/docs/README.md +++ b/docs/README.md @@ -23,8 +23,9 @@ The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllr To use the `jekyll` command, you will need to have Jekyll installed. The easiest way to do this is via a Ruby Gem, see the [jekyll installation instructions](http://jekyllrb.com/docs/installation). -If not already installed, you need to install `kramdown` with `sudo gem install kramdown`. -Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory +If not already installed, you need to install `kramdown` and `jekyll-redirect-from` Gems +with `sudo gem install kramdown jekyll-redirect-from`. +Execute `jekyll build` from the `docs/` directory. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/_config.yml b/docs/_config.yml index 45b78fe724a50..d3ea2625c7448 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -1,5 +1,7 @@ -pygments: true +highlighter: pygments markdown: kramdown +gems: + - jekyll-redirect-from # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b30ab1e5218c0..a53e8a775b71f 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -109,7 +109,7 @@
  • Hardware Provisioning
  • 3rd-Party Hadoop Distros
  • -
  • Building Spark with Maven
  • +
  • Building Spark
  • Contributing to Spark
  • diff --git a/docs/building-with-maven.md b/docs/building-spark.md similarity index 87% rename from docs/building-with-maven.md rename to docs/building-spark.md index bce7412c7d4c9..2378092d4a1a8 100644 --- a/docs/building-with-maven.md +++ b/docs/building-spark.md @@ -1,6 +1,7 @@ --- layout: global -title: Building Spark with Maven +title: Building Spark +redirect_from: "building-with-maven.html" --- * This will become a table of contents (this text will be scraped). @@ -159,4 +160,21 @@ then ship it over to the cluster. We are investigating the exact cause for this. The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +# Building with SBT +Maven is the official recommendation for packaging Spark, and is the "build of reference". +But SBT is supported for day-to-day development since it can provide much faster iterative +compilation. More advanced developers may wish to use SBT. + +The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables +can be set to control the SBT build. For example: + + sbt/sbt -Pyarn -Phadoop-2.3 compile + +# Speeding up Compilation with Zinc + +[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental +compiler. When run locally as a background process, it speeds up builds of Scala-based projects +like Spark. Developers who regularly recompile Spark with Maven will be the most interested in +Zinc. The project site gives instructions for building and running `zinc`; OS X users can +install it using `brew install zinc`. \ No newline at end of file diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index ab1023b8f1842..dd73e9dc54440 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -11,7 +11,7 @@ with these distributions: When compiling Spark, you'll need to specify the Hadoop version by defining the `hadoop.version` property. For certain versions, you will need to specify additional profiles. For more detail, -see the guide on [building with maven](building-with-maven.html#specifying-the-hadoop-version): +see the guide on [building with maven](building-spark.html#specifying-the-hadoop-version): mvn -Dhadoop.version=1.0.4 -DskipTests clean package mvn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package diff --git a/docs/index.md b/docs/index.md index 7fe6b43d32af7..e8ebadbd4e427 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page contains Spark packages for many popular HDFS versions. If you'd like to build Spark from -scratch, visit [building Spark with Maven](building-with-maven.html). +scratch, visit [Building Spark](building-spark.html). Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`, @@ -105,7 +105,7 @@ options for deployment: * [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) -* [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system +* [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) **External Resources:** diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 212248bcce1c1..74bcc2eeb65f6 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -11,7 +11,7 @@ was added to Spark in version 0.6.0, and improved in subsequent releases. Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. Binary distributions can be downloaded from the Spark project website. -To build Spark yourself, refer to the [building with Maven guide](building-with-maven.html). +To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index c6090d9ec30c7..379eb513d521e 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -108,7 +108,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download Spark source and follow the [instructions](building-with-maven.html) to build Spark with profile *-Pkinesis-asl*. +- Download Spark source and follow the [instructions](building-spark.html) to build Spark with profile *-Pkinesis-asl*. mvn -Pkinesis-asl -DskipTests clean package diff --git a/make-distribution.sh b/make-distribution.sh index 9b012b9222db4..884659954a491 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -40,7 +40,7 @@ function exit_with_usage { echo "" echo "usage:" echo "./make-distribution.sh [--name] [--tgz] [--with-tachyon] " - echo "See Spark's \"Building with Maven\" doc for correct Maven options." + echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" exit 1 } From 7b8008f5a4d413b61aa88fbc60959e98e59f17dd Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 16 Sep 2014 09:21:03 -0700 Subject: [PATCH 002/315] [SPARK-2182] Scalastyle rule blocking non ascii characters. ...erators. Author: Prashant Sharma Closes #2358 from ScrapCodes/scalastyle-unicode and squashes the following commits: 12a20f2 [Prashant Sharma] [SPARK-2182] Scalastyle rule blocking (non keyboard typeable) unicode operators. --- .../scalastyle/NonASCIICharacterChecker.scala | 39 +++++++++++++++++++ scalastyle-config.xml | 1 + 2 files changed, 40 insertions(+) create mode 100644 project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala new file mode 100644 index 0000000000000..3d43c35299555 --- /dev/null +++ b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.scalastyle + +import java.util.regex.Pattern + +import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} + +import scalariform.lexer.Token +import scalariform.parser.CompilationUnit + +class NonASCIICharacterChecker extends ScalariformChecker { + val errorKey: String = "non.ascii.character.disallowed" + + override def verify(ast: CompilationUnit): List[ScalastyleError] = { + ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList + } + + private def hasNonAsciiChars(x: Token) = + x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) + .matcher(x.text.trim).matches() + +} diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 76ba1ecca33ab..c54f8b72ebf42 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -140,5 +140,6 @@ + From 86d253ec4e2ed94c68687d575f9e2dfbb44463e1 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 16 Sep 2014 11:21:30 -0700 Subject: [PATCH 003/315] [SPARK-3527] [SQL] Strip the string message Author: Cheng Hao Closes #2392 from chenghao-intel/trim and squashes the following commits: e52024f [Cheng Hao] trim the string message --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c551c7c9877e8..7dbaf7faff0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -414,7 +414,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def simpleString: String = s"""== Physical Plan == |${stringOrError(executedPlan)} - """ + """.stripMargin.trim override def toString: String = // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) From 9d5fa763d8559ac412a18d7a2f43c4368a0af897 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 16 Sep 2014 11:39:57 -0700 Subject: [PATCH 004/315] [SPARK-3519] add distinct(n) to PySpark Added missing rdd.distinct(numPartitions) and associated tests Author: Matthew Farrellee Closes #2383 from mattf/SPARK-3519 and squashes the following commits: 30b837a [Matthew Farrellee] Combine test cases to save on JVM startups 6bc4a2c [Matthew Farrellee] [SPARK-3519] add distinct(n) to SchemaRDD in PySpark 7a17f2b [Matthew Farrellee] [SPARK-3519] add distinct(n) to PySpark --- python/pyspark/rdd.py | 4 ++-- python/pyspark/sql.py | 7 +++++-- python/pyspark/tests.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 21f182b0ff137..cb09c191bed71 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -301,7 +301,7 @@ def func(iterator): return ifilter(f, iterator) return self.mapPartitions(func, True) - def distinct(self): + def distinct(self, numPartitions=None): """ Return a new RDD containing the distinct elements in this RDD. @@ -309,7 +309,7 @@ def distinct(self): [1, 2, 3] """ return self.map(lambda x: (x, None)) \ - .reduceByKey(lambda x, _: x) \ + .reduceByKey(lambda x, _: x, numPartitions) \ .map(lambda (x, _): x) def sample(self, withReplacement, fraction, seed=None): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index fc9310fef318c..eac55cbe15193 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1694,8 +1694,11 @@ def coalesce(self, numPartitions, shuffle=False): rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) return SchemaRDD(rdd, self.sql_ctx) - def distinct(self): - rdd = self._jschema_rdd.distinct() + def distinct(self, numPartitions=None): + if numPartitions is None: + rdd = self._jschema_rdd.distinct() + else: + rdd = self._jschema_rdd.distinct(numPartitions) return SchemaRDD(rdd, self.sql_ctx) def intersection(self, other): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f255b44359fec..0b3854347ad2e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -587,6 +587,14 @@ def test_repartitionAndSortWithinPartitions(self): self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEquals(rdd.getNumPartitions(), 10) + self.assertEquals(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEquals(result.getNumPartitions(), 5) + self.assertEquals(result.count(), 3) + class TestSQL(PySparkTestCase): @@ -636,6 +644,15 @@ def test_basic_functions(self): srdd.count() srdd.collect() + def test_distinct(self): + rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) + srdd = self.sqlCtx.jsonRDD(rdd) + self.assertEquals(srdd.getNumPartitions(), 10) + self.assertEquals(srdd.distinct().count(), 3) + result = srdd.distinct(5) + self.assertEquals(result.getNumPartitions(), 5) + self.assertEquals(result.count(), 3) + class TestIO(PySparkTestCase): From 7583699873fb4f252c6ce65db1096783ef438731 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 16 Sep 2014 11:40:28 -0700 Subject: [PATCH 005/315] [SPARK-3308][SQL] Ability to read JSON Arrays as tables This PR aims to support reading top level JSON arrays and take every element in such an array as a row (an empty array will not generate a row). JIRA: https://issues.apache.org/jira/browse/SPARK-3308 Author: Yin Huai Closes #2400 from yhuai/SPARK-3308 and squashes the following commits: 990077a [Yin Huai] Handle top level JSON arrays. --- .../org/apache/spark/sql/json/JsonRDD.scala | 10 +++++++--- .../org/apache/spark/sql/json/JsonSuite.scala | 17 +++++++++++++++++ .../apache/spark/sql/json/TestJsonData.scala | 7 +++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 873221835daf8..0f27fd13e7379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -287,9 +287,13 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() - iter.map { record => - val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) - parsed.asInstanceOf[Map[String, Any]] + iter.flatMap { record => + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } + + parsed } }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index b50d93855405a..685e788207725 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -622,4 +622,21 @@ class JsonSuite extends QueryTest { ("str1", Nil, "str4", 2) :: Nil ) } + + test("SPARK-3308 Read top level JSON arrays") { + val jsonSchemaRDD = jsonRDD(jsonArray) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select a, b, c + |from jsonTable + """.stripMargin), + ("str_a_1", null, null) :: + ("str_a_2", null, null) :: + (null, "str_b_3", null) :: + ("str_a_4", "str_b_4", "str_c_4") ::Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 5f0b3959a63ad..fc833b8b54e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -136,4 +136,11 @@ object TestJsonData { ] ]] }""" :: Nil) + + val jsonArray = + TestSQLContext.sparkContext.parallelize( + """[{"a":"str_a_1"}]""" :: + """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """[]""" :: Nil) } From 30f288ae34a67307aa45b7aecbd0d02a0a14fe69 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 16 Sep 2014 11:42:26 -0700 Subject: [PATCH 006/315] [SPARK-2890][SQL] Allow reading of data when case insensitive resolution could cause possible ambiguity. Throwing an error in the constructor makes it possible to run queries, even when there is no actual ambiguity. Remove this check in favor of throwing an error in analysis when they query is actually is ambiguous. Also took the opportunity to add test cases that would have caught a subtle bug in my first attempt at fixing this and refactor some other test code. Author: Michael Armbrust Closes #2209 from marmbrus/sameNameStruct and squashes the following commits: 729cca4 [Michael Armbrust] Better tests. a003aeb [Michael Armbrust] Remove error (it'll be caught in analysis). --- .../spark/sql/catalyst/types/dataTypes.scala | 4 -- .../sql/hive/execution/HiveUdfSuite.scala | 67 ++++++++++++------- 2 files changed, 44 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 70c6d06cf2534..49520b7678e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -308,13 +308,9 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - - private def validateFields(fields: Seq[StructField]): Boolean = - fields.map(field => field.name).distinct.size == fields.size } case class StructType(fields: Seq[StructField]) extends DataType { - require(StructType.validateFields(fields), "Found fields with the same name.") /** * Returns all field names in a [[Seq]]. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index b6b8592344ef5..cc125d539c3c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,47 +17,68 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHive -import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext._ +import java.io.{DataOutput, DataInput} import java.util -import org.apache.hadoop.fs.{FileSystem, Path} +import java.util.Properties + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.io.Writable import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} -import java.util.Properties + import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import scala.collection.JavaConversions._ -import java.io.{DataOutput, DataInput} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject +import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) + /** * A test suite for Hive custom UDFs. */ class HiveUdfSuite extends HiveComparisonTest { - TestHive.sql( - """ + test("spark sql udf test that returns a struct") { + registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(sql( + """ + |SELECT getStruct(1).f1, + | getStruct(1).f2, + | getStruct(1).f3, + | getStruct(1).f4, + | getStruct(1).f5 FROM src LIMIT 1 + """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + } + + test("hive struct udf") { + sql( + """ |CREATE EXTERNAL TABLE hiveUdfTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) |ROW FORMAT SERDE '%s' |STORED AS SEQUENCEFILE - """.stripMargin.format(classOf[PairSerDe].getName) - ) - - TestHive.sql( - "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" - .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) - ) - - TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) - - TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - - TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + """. + stripMargin.format(classOf[PairSerDe].getName)) + + val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + sql(s""" + ALTER TABLE hiveUdfTestTable + ADD IF NOT EXISTS PARTITION(partition='testUdf') + LOCATION '$location'""") + + sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") + sql("SELECT testUdf(pair) FROM hiveUdfTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From 8e7ae477ba40a064d27cf149aa211ff6108fe239 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Tue, 16 Sep 2014 11:45:35 -0700 Subject: [PATCH 007/315] [SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SchemaRDD overrides RDD functions, including collect, count, and take, with optimized versions making use of the query optimizer. The java and python interface classes wrapping SchemaRDD need to ensure the optimized versions are called as well. This patch overrides relevant calls in the python and java interfaces with optimized versions. Adds a new Row serialization pathway between python and java, based on JList[Array[Byte]] versus the existing RDD[Array[Byte]]. I wasn’t overjoyed about doing this, but I noticed that some QueryPlans implement optimizations in executeCollect(), which outputs an Array[Row] rather than the typical RDD[Row] that can be shipped to python using the existing serialization code. To me it made sense to ship the Array[Row] over to python directly instead of converting it back to an RDD[Row] just for the purpose of sending the Rows to python using the existing serialization code. Author: Aaron Staple Closes #1592 from staple/SPARK-2314 and squashes the following commits: 89ff550 [Aaron Staple] Merge with master. 6bb7b6c [Aaron Staple] Fix typo. b56d0ac [Aaron Staple] [SPARK-2314][SQL] Override count in JavaSchemaRDD, forwarding to SchemaRDD's count. 0fc9d40 [Aaron Staple] Fix comment typos. f03cdfa [Aaron Staple] [SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect. --- .../apache/spark/api/python/PythonRDD.scala | 2 +- python/pyspark/sql.py | 47 +++++++++++++++++-- .../org/apache/spark/sql/SchemaRDD.scala | 37 ++++++++++----- .../spark/sql/api/java/JavaSchemaRDD.scala | 2 + 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d5002fa02992b..12b345a8fa7c3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -776,7 +776,7 @@ private[spark] object PythonRDD extends Logging { } /** - * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark. */ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index eac55cbe15193..621a556ec6356 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -30,6 +30,7 @@ from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync from itertools import chain, ifilter, imap @@ -1550,6 +1551,18 @@ def id(self): self._id = self._jrdd.id() return self._id + def limit(self, num): + """Limit the result count to the number specified. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.limit(2).collect() + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + >>> srdd.limit(0).collect() + [] + """ + rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() + return SchemaRDD(rdd, self.sql_ctx) + def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. @@ -1626,15 +1639,39 @@ def count(self): return self._jschema_rdd.count() def collect(self): - """ - Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows in this RDD. - Each object in the list is on Row, the fields can be accessed as + Each object in the list is a Row, the fields can be accessed as attributes. + + Unlike the base RDD implementation of collect, this implementation + leverages the query optimizer to perform a collect on the SchemaRDD, + which supports features such as filter pushdown. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.collect() + [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - rows = RDD.collect(self) + with SCCallSiteSync(self.context) as css: + bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() cls = _create_cls(self.schema()) - return map(cls, rows) + return map(cls, self._collect_iterator_through_file(bytesInJava)) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + Unlike the base RDD implementation of take, this implementation + leverages the query optimizer to perform a collect on a SchemaRDD, + which supports features such as filter pushdown. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.take(2) + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + """ + return self.limit(num).collect() # Convert each object in the RDD to a Row with the right class # for this SchemaRDD, so that fields can be accessed as attributes. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d2ceb4a2b0b25..3bc5dce095511 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -377,15 +377,15 @@ class SchemaRDD( def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + * Helper for converting a Row to a simple Array suitable for pyspark serialization. */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + private def rowToJArray(row: Row, structType: StructType): Array[Any] = { import scala.collection.Map def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (obj: Row, struct: StructType) => rowToArray(obj, struct) + case (obj: Row, struct: StructType) => rowToJArray(obj, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -402,22 +402,37 @@ class SchemaRDD( case (other, _) => other } - def rowToArray(row: Row, structType: StructType): Array[Any] = { - val fields = structType.fields.map(field => field.dataType) - row.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray - } + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray + } + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + */ + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToArray(row, rowSchema) + rowToJArray(row, rowSchema) }.grouped(100).map(batched => pickle.dumps(batched.toArray)) } } + /** + * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same + * format as javaToPython. It is used by pyspark. + */ + private[sql] def collectToPython: JList[Array[Byte]] = { + val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) + val pickle = new Pickler + new java.util.ArrayList(collect().map { row => + rowToJArray(row, rowSchema) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) + } + /** * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value * of base RDD functions that do not change schema. @@ -433,7 +448,7 @@ class SchemaRDD( } // ======================================================================= - // Overriden RDD actions + // Overridden RDD actions // ======================================================================= override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 4d799b4038fdd..e7faba0c7f620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -112,6 +112,8 @@ class JavaSchemaRDD( new java.util.ArrayList(arr) } + override def count(): Long = baseSchemaRDD.count + override def take(num: Int): JList[Row] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_)) From df90e81fd383c0d89dee6db16d5520def9190c56 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Tue, 16 Sep 2014 11:48:20 -0700 Subject: [PATCH 008/315] [Docs] minor punctuation fix Author: Nicholas Chammas Closes #2414 from nchammas/patch-1 and squashes the following commits: 14664bf [Nicholas Chammas] [Docs] minor punctuation fix --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b05bbfb5a594c..8dd8b70696aa2 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests Please see the guidance on how to -[run all automated tests](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-AutomatedTesting) +[run all automated tests](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-AutomatedTesting). ## A Note About Hadoop Versions From 84073eb1172dc959936149265378f6e24d303685 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 16 Sep 2014 11:51:46 -0700 Subject: [PATCH 009/315] [SQL][DOCS] Improve section on thrift-server Taken from liancheng's updates. Merged conflicts with #2316. Author: Michael Armbrust Closes #2384 from marmbrus/sqlDocUpdate and squashes the following commits: 2db6319 [Michael Armbrust] @liancheng's updates --- docs/sql-programming-guide.md | 58 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 8d41fdec699e9..c498b41c43380 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -128,7 +128,7 @@ feature parity with a HiveContext. -The specific variant of SQL that is used to parse queries can also be selected using the +The specific variant of SQL that is used to parse queries can also be selected using the `spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the @@ -139,7 +139,7 @@ default is "hiveql", though "sql" is also available. Since the HiveQL parser is Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section +Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section describes the various methods for loading data into a SchemaRDD. ## RDDs @@ -152,7 +152,7 @@ while writing your Spark application. The second method for creating SchemaRDDs is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct SchemaRDDs when the columns and their types are not known until runtime. - + ### Inferring the Schema Using Reflection
    @@ -193,7 +193,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -480,7 +480,7 @@ for name in names.collect(): [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. ### Loading Data Programmatically @@ -562,7 +562,7 @@ for teenName in teenNames.collect():
    -
    + ### Configuration @@ -808,7 +808,7 @@ memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove Note that if you call `cache` rather than `cacheTable`, tables will _not_ be cached using the in-memory columnar format, and therefore `cacheTable` is strongly recommended for this use case. -Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running `SET key=value` commands using SQL. @@ -881,10 +881,32 @@ To start the JDBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh -The default port the server listens on is 10000. To listen on customized host and port, please set -the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may -run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can -use beeline to test the Thrift JDBC server: +This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this +bahaviour via either environment variables, i.e.: + +{% highlight bash %} +export HIVE_SERVER2_THRIFT_PORT= +export HIVE_SERVER2_THRIFT_BIND_HOST= +./sbin/start-thriftserver.sh \ + --master \ + ... +``` +{% endhighlight %} + +or system properties: + +{% highlight bash %} +./sbin/start-thriftserver.sh \ + --hiveconf hive.server2.thrift.port= \ + --hiveconf hive.server2.thrift.bind.host= \ + --master + ... +``` +{% endhighlight %} + +Now you can use beeline to test the Thrift JDBC server: ./bin/beeline @@ -930,7 +952,7 @@ SQL deprecates this property in favor of `spark.sql.shuffle.partitions`, whose d is 200. Users may customize this property via `SET`: SET spark.sql.shuffle.partitions=10; - SELECT page, count(*) c + SELECT page, count(*) c FROM logs_last_month_cached GROUP BY page ORDER BY c DESC LIMIT 10; @@ -1139,7 +1161,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c
    All data types of Spark SQL are located in the package `org.apache.spark.sql`. -You can access them by doing +You can access them by doing {% highlight scala %} import org.apache.spark.sql._ {% endhighlight %} @@ -1245,7 +1267,7 @@ import org.apache.spark.sql._
    - - + + + + + + From 14f222f7f76cc93633aae27a94c0e556e289ec56 Mon Sep 17 00:00:00 2001 From: Qiping Li Date: Thu, 9 Oct 2014 01:36:58 -0700 Subject: [PATCH 238/315] [SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes). ### Implementation Details Each node now has a `impurity` field and the `predict` is changed from type `Double` to type `Predict`(this can be used to compute predict probability in the future) When compute best splits for each node, we also compute impurity and predict for the child nodes, which is used to constructed newly allocated child nodes. So at level L, we have set impurity and predict for nodes at level L +1. If level L+1 is the last level, then we can avoid aggregation. What's more, calculation of parent impurity in Top nodes for each tree needs to be treated differently because we have to compute impurity and predict for them first. In `binsToBestSplit`, if current node is top node(level == 0), we calculate impurity and predict first. after finding best split, top node's predict and impurity is set to the calculated value. Non-top nodes's impurity and predict are already calculated and don't need to be recalculated again. I have considered to add a initialization step to set top nodes' impurity and predict and then we can treat all nodes in the same way, but this will need a lot of duplication of code(all the code to do seq operation(BinSeqOp) needs to be duplicated), so I choose the current way. CC mengxr manishamde jkbradley, please help me review this, thanks. Author: Qiping Li Closes #2708 from chouqin/avoid-agg and squashes the following commits: 8e269ea [Qiping Li] adjust code and comments eefeef1 [Qiping Li] adjust comments and check child nodes' impurity c41b1b6 [Qiping Li] fix pyspark unit test 7ad7a71 [Qiping Li] fix unit test 822c912 [Qiping Li] add comments and unit test e41d715 [Qiping Li] fix bug in test suite 6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training --- .../spark/mllib/tree/DecisionTree.scala | 97 +++++++++++------ .../tree/model/InformationGainStats.scala | 9 +- .../apache/spark/mllib/tree/model/Node.scala | 37 +++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 102 ++++++++++++++++-- 4 files changed, 197 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..03eeaa707715b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() @@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. * @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that current split doesn't satisfies minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + "split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id))) rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..98a72b0c4d750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. without groups") { @@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { From 1e0aa4deba65aa1241b9a30edb82665eae27242f Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 9 Oct 2014 09:22:32 -0700 Subject: [PATCH 239/315] [Minor] use norm operator after breeze 0.10 upgrade cc mengxr Author: GuoQiang Li Closes #2730 from witgo/SPARK-3856 and squashes the following commits: 2cffce1 [GuoQiang Li] use norm operator after breeze 0.10 upgrade --- .../spark/mllib/feature/NormalizerSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..2bf9d9816ae45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite +import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) From 73bf3f2e0c03216aa29c25fea2d97205b5977903 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 9 Oct 2014 11:27:21 -0700 Subject: [PATCH 240/315] [SPARK-3741] Make ConnectionManager propagate errors properly and add mo... ...re logs to avoid Executors swallowing errors This PR made the following changes: * Register a callback to `Connection` so that the error will be propagated properly. * Add more logs so that the errors won't be swallowed by Executors. * Use trySuccess/tryFailure because `Promise` doesn't allow to call success/failure more than once. Author: zsxwing Closes #2593 from zsxwing/SPARK-3741 and squashes the following commits: 1d5aed5 [zsxwing] Fix naming 0b8a61c [zsxwing] Merge branch 'master' into SPARK-3741 764aec5 [zsxwing] [SPARK-3741] Make ConnectionManager propagate errors properly and add more logs to avoid Executors swallowing errors --- .../apache/spark/network/nio/Connection.scala | 35 +-- .../spark/network/nio/ConnectionManager.scala | 206 +++++++++++++----- 2 files changed, 172 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..6b00190c5eccc 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } + + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -153,17 +182,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +223,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +256,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) } } ) } @@ -246,16 +295,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +497,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +526,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +541,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +559,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +708,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -770,6 +827,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +845,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +893,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +905,23 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val e = new IOException( + "sendMessageReliably failed with ACK that signalled a remote error") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) From b77a02f41c60d869f48b65e72ed696c05b30bc48 Mon Sep 17 00:00:00 2001 From: Vida Ha Date: Thu, 9 Oct 2014 13:13:31 -0700 Subject: [PATCH 241/315] [SPARK-3752][SQL]: Add tests for different UDF's Author: Vida Ha Closes #2621 from vidaha/vida/SPARK-3752 and squashes the following commits: d7fdbbc [Vida Ha] Add tests for different UDF's --- .../hive/execution/UDFIntegerToString.java | 26 ++++ .../sql/hive/execution/UDFListListInt.java | 51 ++++++++ .../sql/hive/execution/UDFListString.java | 38 ++++++ .../sql/hive/execution/UDFStringString.java | 26 ++++ .../sql/hive/execution/UDFTwoListList.java | 28 +++++ .../sql/hive/execution/HiveUdfSuite.scala | 111 +++++++++++++++--- 6 files changed, 265 insertions(+), 15 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java new file mode 100644 index 0000000000000..6c4f378bc5471 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFIntegerToString extends UDF { + public String evaluate(Integer i) { + return i.toString(); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java new file mode 100644 index 0000000000000..d2d39a8c4dc28 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; + +public class UDFListListInt extends UDF { + /** + * + * @param obj + * SQL schema: array> + * Java Type: List> + * @return + */ + public long evaluate(Object obj) { + if (obj == null) { + return 0l; + } + List listList = (List) obj; + long retVal = 0; + for (List aList : listList) { + @SuppressWarnings("unchecked") + List list = (List) aList; + @SuppressWarnings("unchecked") + Integer someInt = (Integer) list.get(1); + try { + retVal += (long) (someInt.intValue()); + } catch (NullPointerException e) { + System.out.println(e); + } + } + return retVal; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java new file mode 100644 index 0000000000000..efd34df293c88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; +import org.apache.commons.lang.StringUtils; + +public class UDFListString extends UDF { + + public String evaluate(Object a) { + if (a == null) { + return null; + } + @SuppressWarnings("unchecked") + List s = (List) a; + + return StringUtils.join(s, ','); + } + + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java new file mode 100644 index 0000000000000..a369188d471e8 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFStringString extends UDF { + public String evaluate(String s1, String s2) { + return s1 + " " + s2; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java new file mode 100644 index 0000000000000..0165591a7ce78 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFTwoListList extends UDF { + public String evaluate(Object o1, Object o2) { + UDFListListInt udf = new UDFListListInt(); + + return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2)); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index e4324e9528f9b..872f28d514efe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,33 +17,37 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataOutput, DataInput} +import java.io.{DataInput, DataOutput} import java.util import java.util.Properties -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} - -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject - -import org.apache.spark.sql.Row +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends HiveComparisonTest { +class HiveUdfSuite extends QueryTest { + import TestHive._ test("spark sql udf test that returns a struct") { registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest { } test("SPARK-2693 udaf aggregates test") { - assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + testData.registerTempTable("integerTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + Seq(Seq("1"), Seq("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + Seq(Seq(0), Seq(2), Seq(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil) + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + Seq(Seq("a,b,c"), Seq("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + Seq(Seq("hello world"), Seq("hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil) + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() } } From 752e90f15e0bb82d283f05eff08df874b48caed9 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Thu, 9 Oct 2014 12:59:14 -0700 Subject: [PATCH 242/315] [SPARK-3711][SQL] Optimize where in clause filter queries The In case class is replaced by a InSet class in case all the filters are literals, which uses a hashset instead of Sequence, thereby giving significant performance improvement (earlier the seq was using a worst case linear match (exists method) since expressions were assumed in the filter list) . Maximum improvement should be visible in case small percentage of large data matches the filter list. Author: Yash Datta Closes #2561 from saucam/branch-1.1 and squashes the following commits: 4bf2d19 [Yash Datta] SPARK-3711: 1. Fix code style and import order 2. Fix optimization condition 3. Add tests for null in filter list 4. Add test case that optimization is not triggered in case of attributes in filter list afedbcd [Yash Datta] SPARK-3711: 1. Add test cases for InSet class in ExpressionEvaluationSuite 2. Add class OptimizedInSuite on the lines of ConstantFoldingSuite, for the optimized In clause 0fc902f [Yash Datta] SPARK-3711: UnaryMinus will be handled by constantFolding bd84c67 [Yash Datta] SPARK-3711: Incorporate review comments. Move optimization of In clause to Optimizer.scala by adding a rule. Add appropriate comments 430f5d1 [Yash Datta] SPARK-3711: Optimize the filter list in case of negative values as well bee98aa [Yash Datta] SPARK-3711: Optimize where in clause filter queries --- .../sql/catalyst/expressions/predicates.scala | 19 ++++- .../sql/catalyst/optimizer/Optimizer.scala | 18 ++++- .../ExpressionEvaluationSuite.scala | 21 +++++ .../catalyst/optimizer/OptimizeInSuite.scala | 76 +++++++++++++++++++ 4 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 329af332d0fa1..1e22b2d03c672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType - object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = apply(BindReferences.bindReference(expression, inputSchema)) @@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } +/** + * Optimized version of In clause, when all filter values of In clause are + * static. + */ +case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) + extends Predicate { + + def children = child + + def nullable = true // TODO: Figure out correct nullability semantics of IN. + override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + + override def eval(input: Row): Any = { + hset.contains(value.eval(input)) + } +} + case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 636d0b95583e4..3693b41404fd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + OptimizeIn) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]] + * which is much faster + */ +object OptimizeIn extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(null)) + InSet(v, HashSet() ++ hSet, v +: list) + } + } +} + /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides. * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 63931af4bac3d..692ed78a7292c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp +import scala.collection.immutable.HashSet + import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ + /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS, one +: s), true) + checkEvaluation(InSet(two, hS, two +: s), true) + checkEvaluation(InSet(two, nS, two +: nullS), true) + checkEvaluation(InSet(nl, nS, nl +: nullS), true) + checkEvaluation(InSet(three, hS, three +: s), false) + checkEvaluation(InSet(three, nS, three +: nullS), false) + checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) + } + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala new file mode 100644 index 0000000000000..97a78ec971c39 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types._ + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class OptimizeInSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification, + OptimizeIn) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("OptimizedIn test: In clause optimized to InSet") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, + UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: In clause not optimized in case filter has attributes") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + comparePlans(optimized, correctAnswer) + } +} From 2c8851343a2e4d1d5b3a2b959eaa651a92982a72 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 9 Oct 2014 13:22:36 -0700 Subject: [PATCH 243/315] [SPARK-3806][SQL] Minor fix for CliSuite To fix two issues in CliSuite 1 CliSuite throw IndexOutOfBoundsException: Exception in thread "Thread-6" java.lang.IndexOutOfBoundsException: 6 at scala.collection.mutable.ResizableArray$class.apply(ResizableArray.scala:43) at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:47) at org.apache.spark.sql.hive.thriftserver.CliSuite.org$apache$spark$sql$hive$thriftserver$CliSuite$$captureOutput$1(CliSuite.scala:67) at org.apache.spark.sql.hive.thriftserver.CliSuite$$anonfun$4.apply(CliSuite.scala:78) at org.apache.spark.sql.hive.thriftserver.CliSuite$$anonfun$4.apply(CliSuite.scala:78) at scala.sys.process.ProcessLogger$$anon$1.out(ProcessLogger.scala:96) at scala.sys.process.BasicIO$$anonfun$processOutFully$1.apply(BasicIO.scala:135) at scala.sys.process.BasicIO$$anonfun$processOutFully$1.apply(BasicIO.scala:135) at scala.sys.process.BasicIO$.readFully$1(BasicIO.scala:175) at scala.sys.process.BasicIO$.processLinesFully(BasicIO.scala:179) at scala.sys.process.BasicIO$$anonfun$processFully$1.apply(BasicIO.scala:164) at scala.sys.process.BasicIO$$anonfun$processFully$1.apply(BasicIO.scala:162) at scala.sys.process.ProcessBuilderImpl$Simple$$anonfun$3.apply$mcV$sp(ProcessBuilderImpl.scala:73) at scala.sys.process.ProcessImpl$Spawn$$anon$1.run(ProcessImpl.scala:22) Actually, it is the Mutil-Threads lead to this problem. 2 Using ```line.startsWith``` instead ```line.contains``` to assert expected answer. This is a tiny bug in CliSuite, for test case "Simple commands", there is a expected answers "5", if we use ```contains``` that means output like "14/10/06 11:```5```4:36 INFO CliDriver: Time taken: 1.078 seconds" or "14/10/06 11:54:36 INFO StatsReportListener: 0% ```5```% 10% 25% 50% 75% 90% 95% 100%" will make the assert true. Author: scwf Closes #2666 from scwf/clisuite and squashes the following commits: 11430db [scwf] fix-clisuite --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3475c2c9db080..d68dd090b5e6c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -62,9 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" - if (line.contains(expectedAnswers(next.get()))) { - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) + if (next.get() < expectedAnswers.size) { + if (line.startsWith(expectedAnswers(next.get()))) { + if (next.incrementAndGet() == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) + } } } } From e7edb723d22869f228b838fd242bf8e6fe73ee19 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Thu, 9 Oct 2014 13:46:26 -0700 Subject: [PATCH 244/315] [SPARK-3868][PySpark] Hard to recognize which module is tested from unit-tests.log ./python/run-tests script display messages about which test it is running currently on stdout but not write them on unit-tests.log. It is harder for us to recognize what test programs were executed and which test was failed. Author: cocoatomo Closes #2724 from cocoatomo/issues/3868-display-testing-module-name and squashes the following commits: c63d9fa [cocoatomo] [SPARK-3868][PySpark] Hard to recognize which module is tested from unit-tests.log --- python/run-tests | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/run-tests b/python/run-tests index 63395f72788f9..f6a96841175e8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" cd "$FWDIR/python" FAILED=0 +LOG_FILE=unit-tests.log -rm -f unit-tests.log +rm -f $LOG_FILE # Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { - echo "Running test: $1" + echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE FAILED=$((PIPESTATUS[0]||$FAILED)) From ec4d40e48186af18e25517e0474020720645f583 Mon Sep 17 00:00:00 2001 From: Mike Timper Date: Thu, 9 Oct 2014 14:02:27 -0700 Subject: [PATCH 245/315] [SPARK-3853][SQL] JSON Schema support for Timestamp fields In JSONRDD.scala, add 'case TimestampType' in the enforceCorrectType function and a toTimestamp function. Author: Mike Timper Closes #2720 from mtimper/master and squashes the following commits: 9386ab8 [Mike Timper] Fix and tests for SPARK-3853 --- .../main/scala/org/apache/spark/sql/json/JsonRDD.scala | 10 ++++++++++ .../scala/org/apache/spark/sql/json/JsonSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f27fd13e7379..fbc2965e61e92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.json import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal +import java.sql.Timestamp import com.fasterxml.jackson.databind.ObjectMapper @@ -361,6 +362,14 @@ private[sql] object JsonRDD extends Logging { } } + private def toTimestamp(value: Any): Timestamp = { + value match { + case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) + case value: java.lang.Long => new Timestamp(value) + case value: java.lang.String => Timestamp.valueOf(value) + } + } + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ if (value == null) { null @@ -377,6 +386,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) + case TimestampType => toTimestamp(value) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 685e788207725..3cfcb2b1aa993 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ +import java.sql.Timestamp + class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -50,6 +52,12 @@ class JsonSuite extends QueryTest { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(new Timestamp(intNumber.toLong), + enforceCorrectType(intNumber.toLong, TimestampType)) + val strDate = "2014-09-30 12:34:56" + checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType)) } test("Get compatible type") { From 1faa1135a3fc0acd89f934f01a4a2edefcb93d33 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 9 Oct 2014 14:50:36 -0700 Subject: [PATCH 246/315] Revert "[SPARK-2805] Upgrade to akka 2.3.4" This reverts commit b9df8af62e8d7b263a668dfb6e9668ab4294ea37. --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 2 +- .../spark/streaming/InputStreamsSuite.scala | 71 +++++++++++++++++++ 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f2687ce6b42b4..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 98a93d1fcb2a3..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b618..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index cbc0bd178d894..1fef79ad1001f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 3b6d4ecbae2c1..7756c89b00cad 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.3.4-spark + 2.2.3-shaded-protobuf 1.7.5 1.2.17 1.0.4 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 6107fcdc447b6..952a74fd5f6de 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.streaming import akka.actor.Actor +import akka.actor.IO +import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -142,6 +144,59 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } + // TODO: This test works in IntelliJ but not through SBT + ignore("actor input stream") { + // Start the server + val testServer = new TestServer() + val port = testServer.port + testServer.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", + // Had to pass the local value of port to prevent from closing over entire scope + StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + outputStream.register() + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = 1 to 9 + val expectedOutput = input.map(x => x.toString) + Thread.sleep(1000) + for (i <- 0 until input.size) { + testServer.send(input(i).toString) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } + + test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -323,6 +378,22 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } +/** This is an actor for testing actor input stream */ +class TestActor(port: Int) extends Actor with ActorHelper { + + def bytesToString(byteString: ByteString) = byteString.utf8String + + override def preStart(): Unit = { + @deprecated("suppress compile time deprecation warning", "1.0.0") + val unit = IOManager(context.system).connect(new InetSocketAddress(port)) + } + + def receive = { + case IO.Read(socket, bytes) => + store(bytesToString(bytes)) + } +} + /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 1c7f0ab302de9f82b1bd6da852d133823bc67c66 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 9 Oct 2014 14:57:27 -0700 Subject: [PATCH 247/315] [SPARK-3339][SQL] Support for skipping json lines that fail to parse This PR aims to provide a way to skip/query corrupt JSON records. To do so, we introduce an internal column to hold corrupt records (the default name is `_corrupt_record`. This name can be changed by setting the value of `spark.sql.columnNameOfCorruptRecord`). When there is a parsing error, we will put the corrupt record in its unparsed format to the internal column. Users can skip/query this column through SQL. * To query those corrupt records ``` -- For Hive parser SELECT `_corrupt_record` FROM jsonTable WHERE `_corrupt_record` IS NOT NULL -- For our SQL parser SELECT _corrupt_record FROM jsonTable WHERE _corrupt_record IS NOT NULL ``` * To skip corrupt records and query regular records ``` -- For Hive parser SELECT field1, field2 FROM jsonTable WHERE `_corrupt_record` IS NULL -- For our SQL parser SELECT field1, field2 FROM jsonTable WHERE _corrupt_record IS NULL ``` Generally, it is not recommended to change the name of the internal column. If the name has to be changed to avoid possible name conflicts, you can use `sqlContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, )` or `sqlContext.sql(SET spark.sql.columnNameOfCorruptRecord=)`. Author: Yin Huai Closes #2680 from yhuai/corruptJsonRecord and squashes the following commits: 4c9828e [Yin Huai] Merge remote-tracking branch 'upstream/master' into corruptJsonRecord 309616a [Yin Huai] Change the default name of corrupt record to "_corrupt_record". b4a3632 [Yin Huai] Merge remote-tracking branch 'upstream/master' into corruptJsonRecord 9375ae9 [Yin Huai] Set the column name of corrupt json record back to the default one after the unit test. ee584c0 [Yin Huai] Provide a way to query corrupt json records as unparsed strings. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 ++ .../org/apache/spark/sql/SQLContext.scala | 14 +++-- .../spark/sql/api/java/JavaSQLContext.scala | 16 +++-- .../org/apache/spark/sql/json/JsonRDD.scala | 30 ++++++--- .../org/apache/spark/sql/json/JsonSuite.scala | 62 ++++++++++++++++++- .../apache/spark/sql/json/TestJsonData.scala | 9 +++ 6 files changed, 116 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f6f4cf3b80d41..07e6e2eccddf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -35,6 +35,7 @@ private[spark] object SQLConf { val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -131,6 +132,9 @@ private[sql] trait SQLConf { private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def columnNameOfCorruptRecord: String = + getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 35561cac3e5e1..014e1e2826724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -195,9 +195,12 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord val appliedSchema = - Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + Option(schema).getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -206,8 +209,11 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val appliedSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c006c4330ff66..f8171c3be3207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * It goes through the entire dataset once to determine the schema. */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val appliedScalaSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord)) + val scalaRowRDD = + JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { */ @Experimental def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow( + json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index fbc2965e61e92..61ee960aad9d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal import java.sql.Timestamp +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -35,16 +36,19 @@ private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], - schema: StructType): RDD[Row] = { - parseJson(json).map(parsed => asRow(parsed, schema)) + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): StructType = { + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val allKeys = + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) createSchema(allKeys) } @@ -274,7 +278,9 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + private def parseJson( + json: RDD[String], + columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -289,12 +295,16 @@ private[sql] object JsonRDD extends Logging { // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() iter.flatMap { record => - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - } + try { + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } - parsed + parsed + } catch { + case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + } } }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3cfcb2b1aa993..7bb08f1b513ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import java.sql.Timestamp @@ -644,7 +646,65 @@ class JsonSuite extends QueryTest { ("str_a_1", null, null) :: ("str_a_2", null, null) :: (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") ::Nil + ("str_a_4", "str_b_4", "str_c_4") :: Nil ) } + + test("Corrupt records") { + // Test if we can query corrupt records. + val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val jsonSchemaRDD = jsonRDD(corruptRecords) + jsonSchemaRDD.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonSchemaRDD.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + (null, null, null, "{") :: + (null, null, null, "") :: + (null, null, null, """{"a":1, b:2}""") :: + (null, null, null, """{"a":{, b:3}""") :: + ("str_a_4", "str_b_4", "str_c_4", null) :: + (null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Seq("{") :: + Seq("") :: + Seq("""{"a":1, b:2}""") :: + Seq("""{"a":{, b:3}""") :: + Seq("]") :: Nil + ) + + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index fc833b8b54e4c..eaca9f0508a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -143,4 +143,13 @@ object TestJsonData { """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) + + val corruptRecords = + TestSQLContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a":1, b:2}""" :: + """{"a":{, b:3}""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """]""" :: Nil) } From 0c0e09f567deb775ee378f5385a16884f68b332d Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 9 Oct 2014 14:59:03 -0700 Subject: [PATCH 248/315] [SPARK-3412][SQL]add missing row api chenghao-intel assigned this to me, check PR #2284 for previous discussion Author: Daoyuan Wang Closes #2529 from adrian-wang/rowapi and squashes the following commits: c6594b2 [Daoyuan Wang] using boxed 7b7e6e3 [Daoyuan Wang] update pattern match 7a39456 [Daoyuan Wang] rename file and refresh getAs[T] 4c18c29 [Daoyuan Wang] remove setAs[T] and null judge 1614493 [Daoyuan Wang] add missing row api --- .../sql/catalyst/expressions/Projection.scala | 15 ++++++++++++++ .../spark/sql/catalyst/expressions/Row.scala | 20 ++++++++++--------- ...ificRow.scala => SpecificMutableRow.scala} | 8 ++++++-- 3 files changed, 32 insertions(+), 11 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{SpecificRow.scala => SpecificMutableRow.scala} (97%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index ef1d12531f109..204904ecf04db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -137,6 +137,9 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -226,6 +229,9 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -309,6 +315,9 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -392,6 +401,9 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -475,6 +487,9 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d68a4fabeac77..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = s"[${this.mkString(",")}]" @@ -118,6 +119,7 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new UnsupportedOperationException def copy() = this } @@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value } + override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } - override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value } + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } - override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value } + override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } override def copy() = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 9cbab3d5d0d0d..570379c533e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def iterator: Iterator[Any] = values.map(_.boxed).iterator - def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String) = update(ordinal, value) - def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] @@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } + + override def getAs[T](i: Int): T = { + values(i).boxed.asInstanceOf[T] + } } From bc3b6cb06153d6b05f311dd78459768b6cf6a404 Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Thu, 9 Oct 2014 15:03:01 -0700 Subject: [PATCH 249/315] [SPARK-3858][SQL] Pass the generator alias into logical plan node The alias parameter is being ignored, which makes it more difficult to specify a qualifier for Generator expressions. Author: Nathan Howell Closes #2721 from NathanHowell/SPARK-3858 and squashes the following commits: 8aa0f43 [Nathan Howell] [SPARK-3858][SQL] Pass the generator alias into logical plan node --- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 2 +- .../test/scala/org/apache/spark/sql/DslQuerySuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 594bf8ffc20e1..948122d42f0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -360,7 +360,7 @@ class SchemaRDD( join: Boolean = false, outer: Boolean = false, alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) + new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) /** * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index d001abb7e1fcc..45e58afe9d9a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest { (1, 1, 1, 2) :: Nil) } + test("SPARK-3858 generator qualifiers are discarded") { + checkAnswer( + arrayData.as('ad) + .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) + .select("ex.data".attr), + Seq(1, 2, 3, 2, 3, 4).map(Seq(_))) + } + test("average") { checkAnswer( testData2.groupBy()(avg('a)), From ac302052870a650d56f2d3131c27755bb2960ad7 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 9 Oct 2014 15:14:58 -0700 Subject: [PATCH 250/315] [SPARK-3813][SQL] Support "case when" conditional functions in Spark SQL. "case when" conditional function is already supported in Spark SQL but there is no support in SqlParser. So added parser support to it. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2678 from ravipesala/SPARK-3813 and squashes the following commits: 70c75a7 [ravipesala] Fixed styles 713ea84 [ravipesala] Updated as per admin comments 709684f [ravipesala] Changed parser to support case when function. --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 14 ++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 15 +++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 854b5b461bdc8..4662f585cfe15 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") + protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val ELSE = Keyword("ELSE") + protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") @@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") + protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. @@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers { IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | + CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { + case we ~ te => + Seq(casePart.fold(we)(EqualTo(_, we)), te) + } + CaseWhen(altExprs ++ elsePart.toList) + } | (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) } | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b9b196ea5a46a..79de1bb855dbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { + checkAnswer( + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { + checkAnswer( + sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + } } From 4e9b551a0b807f5a2cc6679165c8be4e88a3d077 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Oct 2014 16:08:07 -0700 Subject: [PATCH 251/315] [SPARK-3772] Allow `ipython` to be used by Pyspark workers; IPython support improvements: This pull request addresses a few issues related to PySpark's IPython support: - Fix the remaining uses of the '-u' flag, which IPython doesn't support (see SPARK-3772). - Change PYSPARK_PYTHON_OPTS to PYSPARK_DRIVER_PYTHON_OPTS, so that the old name is reserved in case we ever want to allow the worker Python options to be customized (this variable was introduced in #2554 and hasn't landed in a release yet, so this doesn't break any compatibility). - Introduce a PYSPARK_DRIVER_PYTHON option that allows the driver to use `ipython` while the workers use a different Python version. - Attempt to use Python 2.7 by default if PYSPARK_PYTHON is not specified. - Retain the old semantics for IPYTHON=1 and IPYTHON_OPTS (to avoid breaking existing example programs). There are more details in a block comment in `bin/pyspark`. Author: Josh Rosen Closes #2651 from JoshRosen/SPARK-3772 and squashes the following commits: 7b8eb86 [Josh Rosen] More changes to PySpark python executable configuration: c4f5778 [Josh Rosen] [SPARK-3772] Allow ipython to be used by Pyspark workers; IPython fixes: --- bin/pyspark | 51 ++++++++++++++----- .../api/python/PythonWorkerFactory.scala | 8 ++- .../apache/spark/deploy/PythonRunner.scala | 4 +- docs/programming-guide.md | 8 +-- 4 files changed, 51 insertions(+), 20 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 6655725ef8e8e..96f30a260a09e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,22 +50,47 @@ fi . "$FWDIR"/bin/load-spark-env.sh -# Figure out which Python executable to use +# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` +# executable, while the worker would still be launched using PYSPARK_PYTHON. +# +# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added +# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver. +# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set +# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver +# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython +# and executor Python executables. +# +# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables. + +# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set: +if hash python2.7 2>/dev/null; then + # Attempt to use Python 2.7, if installed: + DEFAULT_PYTHON="python2.7" +else + DEFAULT_PYTHON="python" +fi + +# Determine the Python executable to use for the driver: +if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then + # If IPython options are specified, assume user wants to run IPython + # (for backwards-compatibility) + PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" + PYSPARK_DRIVER_PYTHON="ipython" +elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" +fi + +# Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON="ipython" + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 + exit 1 else - PYSPARK_PYTHON="python" + PYSPARK_PYTHON="$DEFAULT_PYTHON" fi fi export PYSPARK_PYTHON -if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" -fi - # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -93,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 + exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else - exec "$PYSPARK_PYTHON" $1 + exec "$PYSPARK_DRIVER_PYTHON" $1 fi exit fi @@ -111,5 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS + exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 71bdf0fe1b917..e314408c067e9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") val worker = pb.start() // Redirect worker stdout and stderr @@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 79b4d7ea41a33..af94b05ce3847 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -34,7 +34,8 @@ object PythonRunner { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) - val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python")) // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -57,6 +58,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 8e8cc1dd983f8..18420afb27e3c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: +use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ PYSPARK_PYTHON=ipython ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} From 2837bf8548db7e9d43f6eefedf5a73feb22daedb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 9 Oct 2014 17:54:02 -0700 Subject: [PATCH 252/315] [SPARK-3798][SQL] Store the output of a generator in a val This prevents it from changing during serialization, leading to corrupted results. Author: Michael Armbrust Closes #2656 from marmbrus/generateBug and squashes the following commits: efa32eb [Michael Armbrust] Store the output of a generator in a val. This prevents it from changing during serialization. --- .../main/scala/org/apache/spark/sql/execution/Generate.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c386fd121c5de..38877c28de3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -39,7 +39,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { + // This must be a val since the generator output expr ids are not preserved by serialization. + protected val generatorOutput: Seq[Attribute] = { if (join && outer) { generator.output.map(_.withNullability(true)) } else { @@ -62,7 +63,7 @@ case class Generate( newProjection(child.output ++ nullValues, child.output) val joinProjection = - newProjection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) val joinedRow = new JoinedRow iter.flatMap {row => From 363baacaded56047bcc63276d729ab911e0336cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 9 Oct 2014 18:21:59 -0700 Subject: [PATCH 253/315] SPARK-3811 [CORE] More robust / standard Utils.deleteRecursively, Utils.createTempDir I noticed a few issues with how temp directories are created and deleted: *Minor* * Guava's `Files.createTempDir()` plus `File.deleteOnExit()` is used in many tests to make a temp dir, but `Utils.createTempDir()` seems to be the standard Spark mechanism * Call to `File.deleteOnExit()` could be pushed into `Utils.createTempDir()` as well, along with this replacement * _I messed up the message in an exception in `Utils` in SPARK-3794; fixed here_ *Bit Less Minor* * `Utils.deleteRecursively()` fails immediately if any `IOException` occurs, instead of trying to delete any remaining files and subdirectories. I've observed this leave temp dirs around. I suggest changing it to continue in the face of an exception and throw one of the possibly several exceptions that occur at the end. * `Utils.createTempDir()` will add a JVM shutdown hook every time the method is called. Even if the subdir is the parent of another parent dir, since this check is inside the hook. However `Utils` manages a set of all dirs to delete on shutdown already, called `shutdownDeletePaths`. A single hook can be registered to delete all of these on exit. This is how Tachyon temp paths are cleaned up in `TachyonBlockManager`. I noticed a few other things that might be changed but wanted to ask first: * Shouldn't the set of dirs to delete be `File`, not just `String` paths? * `Utils` manages the set of `TachyonFile` that have been registered for deletion, but the shutdown hook is managed in `TachyonBlockManager`. Should this logic not live together, and not in `Utils`? it's more specific to Tachyon, and looks a slight bit odd to import in such a generic place. Author: Sean Owen Closes #2670 from srowen/SPARK-3811 and squashes the following commits: 071ae60 [Sean Owen] Update per @vanzin's review da0146d [Sean Owen] Make Utils.deleteRecursively try to delete all paths even when an exception occurs; use one shutdown hook instead of one per method call to delete temp dirs 3a0faa4 [Sean Owen] Standardize on Utils.createTempDir instead of Files.createTempDir --- .../scala/org/apache/spark/TestUtils.scala | 5 +- .../scala/org/apache/spark/util/Utils.scala | 55 +++++++++++++------ .../org/apache/spark/FileServerSuite.scala | 4 +- .../scala/org/apache/spark/FileSuite.scala | 4 +- .../spark/deploy/SparkSubmitSuite.scala | 3 +- .../WholeTextFileRecordReaderSuite.scala | 6 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 21 ++++--- .../scheduler/EventLoggingListenerSuite.scala | 4 +- .../spark/scheduler/ReplayListenerSuite.scala | 4 +- .../spark/storage/DiskBlockManagerSuite.scala | 17 +----- .../apache/spark/util/FileLoggerSuite.scala | 3 +- .../org/apache/spark/util/UtilsSuite.scala | 28 +++++++++- .../spark/mllib/util/MLUtilsSuite.scala | 9 ++- .../spark/repl/ExecutorClassLoaderSuite.scala | 8 +-- .../org/apache/spark/repl/ReplSuite.scala | 4 +- .../spark/streaming/CheckpointSuite.scala | 3 +- .../spark/streaming/InputStreamsSuite.scala | 3 +- .../spark/streaming/MasterFailureTest.scala | 3 +- .../spark/streaming/TestSuiteBase.scala | 5 +- .../spark/deploy/yarn/ClientBaseSuite.scala | 5 +- 20 files changed, 102 insertions(+), 92 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 8ca731038e528..e72826dc25f41 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -26,6 +26,8 @@ import scala.collection.JavaConversions._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import com.google.common.io.Files +import org.apache.spark.util.Utils + /** * Utilities for tests. Included in main codebase since it's used by multiple * projects. @@ -42,8 +44,7 @@ private[spark] object TestUtils { * in order to avoid interference between tests. */ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) createJar(files, jarFile) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3d307b3c16d3e..07477dd460a4b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -168,6 +168,20 @@ private[spark] object Utils extends Logging { private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + // Add a shutdown hook to delete the temp dirs when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { + override def run(): Unit = Utils.logUncaughtExceptions { + logDebug("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + }) + // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -252,14 +266,6 @@ private[spark] object Utils extends Logging { } registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) dir } @@ -666,15 +672,30 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if (file.isDirectory() && !isSymlink(file)) { - for (child <- listFilesSafely(file)) { - deleteRecursively(child) + try { + if (file.isDirectory && !isSymlink(file)) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(file.getAbsolutePath) + } } - } - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } } } } @@ -713,7 +734,7 @@ private[spark] object Utils extends Logging { */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { if (!dir.isDirectory) { - throw new IllegalArgumentException("$dir is not a directory!") + throw new IllegalArgumentException(s"$dir is not a directory!") } val filesAndDirs = dir.listFiles() val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 7e18f45de7b5b..a8867020e457d 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.jar.{JarEntry, JarOutputStream} -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { override def beforeAll() { super.beforeAll() - tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() + tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4a53d25012ad9..a2b74c4419d46 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, FileWriter} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext { override def beforeEach() { super.beforeEach() - tempDir = Files.createTempDir() - tempDir.deleteOnExit() + tempDir = Utils.createTempDir() } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4cba90e8f2afe..1cdf50d5c08c7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers -import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -332,7 +331,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { } def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index d5ebfb3f3fae1..12d1c7b2faba6 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,8 +23,6 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import com.google.common.io.Files - import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite @@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { * 3) Does the contents be the same. */ test("Correctness of WholeTextFileRecordReader.") { - - val dir = Files.createTempDir() - dir.deleteOnExit() + val dir = Utils.createTempDir() println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 75b01191901b8..3620e251cc139 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random -import com.google.common.io.Files import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.spark.{Partitioner, SharedSparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.util.Utils + import org.scalatest.FunSuite class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { @@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - emptyDir.deleteOnExit() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - emptyDir.delete() + val emptyDir = Utils.createTempDir() + try { + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.isEmpty) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } finally { + Utils.deleteRecursively(emptyDir) + } } test("keys and values") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 3efa85431876b..abc300fcffaf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import scala.collection.mutable import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private var logDirPath: Path = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "spark-events") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 48114feee6233..e05f373392d4a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} -import com.google.common.io.Files import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { private var testDir: File = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() } after { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index e4522e00a622d..bc5c74c126b74 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,22 +19,13 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.shuffle.hash.HashShuffleManager - -import scala.collection.mutable import scala.language.reflectiveCalls -import akka.actor.Props -import com.google.common.io.Files import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before override def beforeAll() { super.beforeAll() - rootDir0 = Files.createTempDir() - rootDir0.deleteOnExit() - rootDir1 = Files.createTempDir() - rootDir1.deleteOnExit() + rootDir0 = Utils.createTempDir() + rootDir1 = Utils.createTempDir() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index c3dd156b40514..dc2a05631d83d 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, IOException} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, FunSuite} @@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { private var logDirPathString: String = _ before { - testDir = Files.createTempDir() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "test-file-logger") logDirPathString = logDirPath.toString } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index e63d9d085e385..0344da60dae66 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -112,7 +112,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() + val tmpDir2 = Utils.createTempDir() tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) @@ -141,7 +141,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes across multiple files") { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), Charsets.UTF_8) @@ -308,4 +308,28 @@ class UtilsSuite extends FunSuite { } } + test("deleteRecursively") { + val tempDir1 = Utils.createTempDir() + assert(tempDir1.exists()) + Utils.deleteRecursively(tempDir1) + assert(!tempDir1.exists()) + + val tempDir2 = Utils.createTempDir() + val tempFile1 = new File(tempDir2, "foo.txt") + Files.touch(tempFile1) + assert(tempFile1.exists()) + Utils.deleteRecursively(tempFile1) + assert(!tempFile1.exists()) + + val tempDir3 = new File(tempDir2, "subdir") + assert(tempDir3.mkdir()) + val tempFile2 = new File(tempDir3, "bar.txt") + Files.touch(tempFile2) + assert(tempFile2.exists()) + Utils.deleteRecursively(tempDir2) + assert(!tempDir2.exists()) + assert(!tempDir3.exists()) + assert(!tempFile2.exists()) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 8ef2bb1bf6a78..0dbe766b4d917 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString @@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) val lines = outputDir.listFiles() @@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Vectors.sparse(2, Array(1), Array(-1.0)), Vectors.dense(0.0, 1.0) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "vectors") val path = outputDir.toURI.toString vectors.saveAsTextFile(path) @@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "points") val path = outputDir.toURI.toString points.saveAsTextFile(path) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 3e2ee7541f40d..6a79e76a34db8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader} import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.util.Utils @@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() - tempDir1 = Files.createTempDir() - tempDir1.deleteOnExit() - tempDir2 = Files.createTempDir() - tempDir2.deleteOnExit() + tempDir1 = Utils.createTempDir() + tempDir2 = Utils.createTempDir() url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c8763eb277052..91c9c52c3c98a 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext import org.apache.commons.lang3.StringEscapeUtils @@ -190,8 +189,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8511390cb1ad5..e5592e52b0d2d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -231,8 +231,7 @@ class CheckpointSuite extends TestSuiteBase { // failure, are re-processed or not. test("recovery with file input stream") { // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() var ssc = new StreamingContext(master, framework, Seconds(1)) ssc.checkpoint(checkpointDir) val fileStream = ssc.textFileStream(testDir.toString) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..a44a45a3e9bd6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -98,8 +98,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() val ssc = new StreamingContext(conf, batchDuration) val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index c53c01706083a..5dbb7232009eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -352,8 +352,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) extends Thread with Logging { override def run() { - val localTestDir = Files.createTempDir() - localTestDir.deleteOnExit() + val localTestDir = Utils.createTempDir() var fs = testDir.getFileSystem(new Configuration()) val maxTries = 3 try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 759baacaa4308..9327ff4822699 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, FunSuite} -import com.google.common.io.Files import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -120,9 +120,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Directory where the checkpoint data will be saved lazy val checkpointDir = { - val dir = Files.createTempDir() + val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") - dir.deleteOnExit() dir.toString } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 9bd916100dd2c..17b79ae1d82c4 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -20,13 +20,10 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ @@ -117,7 +114,7 @@ class ClientBaseSuite extends FunSuite with Matchers { doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) From edf02da389f75df5a42465d41f035d6b65599848 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Oct 2014 18:25:06 -0700 Subject: [PATCH 254/315] [SPARK-3654][SQL] Unifies SQL and HiveQL parsers This PR is a follow up of #2590, and tries to introduce a top level SQL parser entry point for all SQL dialects supported by Spark SQL. A top level parser `SparkSQLParser` is introduced to handle the syntaxes that all SQL dialects should recognize (e.g. `CACHE TABLE`, `UNCACHE TABLE` and `SET`, etc.). For all the syntaxes this parser doesn't recognize directly, it fallbacks to a specified function that tries to parse arbitrary input to a `LogicalPlan`. This function is typically another parser combinator like `SqlParser`. DDL syntaxes introduced in #2475 can be moved to here. The `ExtendedHiveQlParser` now only handle Hive specific extensions. Also took the chance to refactor/reformat `SqlParser` for better readability. Author: Cheng Lian Closes #2698 from liancheng/gen-sql-parser and squashes the following commits: ceada76 [Cheng Lian] Minor styling fixes 9738934 [Cheng Lian] Minor refactoring, removes optional trailing ";" in the parser bb2ab12 [Cheng Lian] SET property value can be empty string ce8860b [Cheng Lian] Passes test suites e86968e [Cheng Lian] Removes debugging code 8bcace5 [Cheng Lian] Replaces digit.+ to rep1(digit) (Scala style checking doesn't like it) d15d54f [Cheng Lian] Unifies SQL and HiveQL parsers --- .../spark/sql/catalyst/SparkSQLParser.scala | 186 ++++++++ .../apache/spark/sql/catalyst/SqlParser.scala | 426 +++++++----------- .../sql/catalyst/plans/logical/commands.scala | 15 +- .../org/apache/spark/sql/SQLContext.scala | 9 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../apache/spark/sql/execution/commands.scala | 34 +- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../server/SparkSQLOperationManager.scala | 2 +- .../spark/sql/hive/ExtendedHiveQlParser.scala | 110 +---- .../org/apache/spark/sql/hive/HiveQl.scala | 15 +- .../spark/sql/hive/HiveStrategies.scala | 8 +- 12 files changed, 414 insertions(+), 401 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala new file mode 100644 index 0000000000000..04467342e6ab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.input.CharArrayReader.EofCh + +import org.apache.spark.sql.catalyst.plans.logical._ + +private[sql] abstract class AbstractSparkSQLParser + extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + + protected case class Keyword(str: String) + + protected def start: Parser[LogicalPlan] + + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success(in.source.toString, in.drop(in.source.length())) + } + + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success( + in.source.subSequence(in.offset, in.source.length()).toString, + in.drop(in.source.length())) + } +} + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", "." + ) + + override lazy val token: Parser[Token] = + ( identChar ~ (identChar | digit).* ^^ + { case first ~ rest => processIdent((first :: rest).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} + +/** + * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. + * + * @param fallback A function that parses an input string to a logical plan + */ +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { + + // A parser for the key-value part of the "SET [key = [value ]]" syntax + private object SetCommandParser extends RegexParsers { + private val key: Parser[String] = "(?m)[^=]+".r + + private val value: Parser[String] = "(?m).*$".r + + private val pair: Parser[LogicalPlan] = + (key ~ ("=".r ~> value).?).? ^^ { + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) + } + + def apply(input: String): LogicalPlan = parseAll(pair, input) match { + case Success(plan, _) => plan + case x => sys.error(x.toString) + } + } + + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val TABLE = Keyword("TABLE") + protected val SOURCE = Keyword("SOURCE") + protected val UNCACHE = Keyword("UNCACHE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + private val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | shell | source | others + + private lazy val cache: Parser[LogicalPlan] = + CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + } + + private lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + + private lazy val set: Parser[LogicalPlan] = + SET ~> restInput ^^ { + case input => SetCommandParser(input) + } + + private lazy val shell: Parser[LogicalPlan] = + "!" ~> restInput ^^ { + case input => ShellCommand(input.trim) + } + + private lazy val source: Parser[LogicalPlan] = + SOURCE ~> restInput ^^ { + case input => SourceCommand(input.trim) + } + + private lazy val others: Parser[LogicalPlan] = + wholeInput ^^ { + case input => fallback(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 4662f585cfe15..b4d606d37e732 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -18,10 +18,6 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -39,31 +35,7 @@ import org.apache.spark.sql.catalyst.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - +class SqlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -100,7 +72,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") protected val LAST = Keyword("LAST") - protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") @@ -128,7 +99,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val THEN = Keyword("THEN") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") - protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") protected val WHEN = Keyword("WHEN") @@ -136,7 +106,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { // Use reflection to find the reserved words defined in this class. protected val reservedWords = - this.getClass + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) @@ -150,86 +121,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } } - protected lazy val query: Parser[LogicalPlan] = ( - select * ( - UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | - INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } | - EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | - UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + protected lazy val start: Parser[LogicalPlan] = + ( select * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache | unCache - ) + | insert + ) protected lazy val select: Parser[LogicalPlan] = - SELECT ~> opt(DISTINCT) ~ projections ~ - opt(from) ~ opt(filter) ~ - opt(grouping) ~ - opt(having) ~ - opt(orderBy) ~ - opt(limit) <~ opt(";") ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) - val withFilter = f.map(f => Filter(f, base)).getOrElse(base) - val withProjection = - g.map {g => - Aggregate(g, assignAliases(p), withFilter) - }.getOrElse(Project(assignAliases(p), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) - withLimit - } + SELECT ~> DISTINCT.? ~ + repsep(projection, ",") ~ + (FROM ~> relations).? ~ + (WHERE ~> expression).? ~ + (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ + (HAVING ~> expression).? ~ + (ORDER ~ BY ~> ordering).? ~ + (LIMIT ~> expression).? ^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = g + .map(Aggregate(_, assignAliases(p), withFilter)) + .getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving) + val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) + withLimit + } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ { - case o ~ r ~ s => - val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" - InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val unCache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => UncacheTableCommand(tableName) + INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) } - protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") - protected lazy val projection: Parser[Expression] = - expression ~ (opt(AS) ~> opt(ident)) ^^ { - case e ~ None => e - case e ~ Some(a) => Alias(e, a)() + expression ~ (AS.? ~> ident.?) ^^ { + case e ~ a => a.fold(e)(Alias(e, _)()) } - protected lazy val from: Parser[LogicalPlan] = FROM ~> relations - - protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation - // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | - relation + ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + | relation + ) protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | - relationFactor + joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ident ~ (opt(AS) ~> opt(ident)) ^^ { - case tableName ~ alias => UnresolvedRelation(None, tableName, alias) - } | - "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + ( ident ~ (opt(AS) ~> opt(ident)) ^^ { + case tableName ~ alias => UnresolvedRelation(None, tableName, alias) + } + | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } + ) protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { - case r1 ~ jt ~ _ ~ r2 ~ cond => + relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ { + case r1 ~ jt ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) } @@ -237,160 +190,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers { ON ~> expression protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter - - protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } - - protected lazy val orderBy: Parser[Seq[SortOrder]] = - ORDER ~> BY ~> ordering + ( INNER ^^^ Inner + | LEFT ~ SEMI ^^^ LeftSemi + | LEFT ~ OUTER.? ^^^ LeftOuter + | RIGHT ~ OUTER.? ^^^ RightOuter + | FULL ~ OUTER.? ^^^ FullOuter + ) protected lazy val ordering: Parser[Seq[SortOrder]] = - rep1sep(singleOrder, ",") | - rep1sep(expression, ",") ~ opt(direction) ^^ { - case exps ~ None => exps.map(SortOrder(_, Ascending)) - case exps ~ Some(d) => exps.map(SortOrder(_, d)) - } + ( rep1sep(singleOrder, ",") + | rep1sep(expression, ",") ~ direction.? ^^ { + case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + } + ) protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } protected lazy val direction: Parser[SortDirection] = - ASC ^^^ Ascending | - DESC ^^^ Descending - - protected lazy val grouping: Parser[Seq[Expression]] = - GROUP ~> BY ~> rep1sep(expression, ",") - - protected lazy val having: Parser[Expression] = - HAVING ~> expression - - protected lazy val limit: Parser[Expression] = - LIMIT ~> expression + ( ASC ^^^ Ascending + | DESC ^^^ Descending + ) - protected lazy val expression: Parser[Expression] = orExpression + protected lazy val expression: Parser[Expression] = + orExpression protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) protected lazy val andExpression: Parser[Expression] = - comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | - termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | - termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | - termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | - termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { - case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - } | - termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | - termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ e2 => In(e1, e2) - } | - termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) - } | - termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | - termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | - NOT ~> termExpression ^^ {e => Not(e)} | - termExpression + ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } + | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } + | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } + | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } + | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } + | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + } + | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => In(e1, e2) + } + | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => Not(In(e1, e2)) + } + | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } + | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } + | NOT ~> termExpression ^^ {e => Not(e)} + | termExpression + ) protected lazy val termExpression: Parser[Expression] = - productExpression * ( - "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | - "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + productExpression * + ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } + | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } + ) protected lazy val productExpression: Parser[Expression] = - baseExpression * ( - "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | - "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | - "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } - ) + baseExpression * + ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } + | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } + | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } + ) protected lazy val function: Parser[Expression] = - SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } | - SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | - COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | - COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | - COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | - APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { - case exp => ApproxCountDistinct(exp) - } | - APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { - case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) - } | - FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | - LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | - AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | - MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } | - MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | - UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } | - LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } | - IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case c ~ "," ~ t ~ "," ~ f => If(c,t,f) - } | - CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ - (ELSE ~> expression).? <~ END ^^ { - case casePart ~ altPart ~ elsePart => - val altExprs = altPart.flatMap { - case we ~ te => - Seq(casePart.fold(we)(EqualTo(_, we)), te) + ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } + | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } + | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } + | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } + | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ + { case exp => ApproxCountDistinct(exp) } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } + | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } + | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } + | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } + | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } + | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } + | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } + | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } + | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case c ~ t ~ f => If(c, t, f) } + | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? <~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { case whenExpr ~ thenExpr => + Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr) + } + CaseWhen(altExprs ++ elsePart.toList) } - CaseWhen(altExprs ++ elsePart.toList) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) - } | - SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | - ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | - ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { - case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) - } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p ~ l => Substring(s, p, l) } + | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } + | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + ) protected lazy val cast: Parser[Expression] = - CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } protected lazy val literal: Parser[Literal] = - numericLit ^^ { - case i if i.toLong > Int.MaxValue => Literal(i.toLong) - case i => Literal(i.toInt) - } | - NULL ^^^ Literal(null, NullType) | - floatLit ^^ {case f => Literal(f.toDouble) } | - stringLit ^^ {case s => Literal(s, StringType) } + ( numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } + | NULL ^^^ Literal(null, NullType) + | floatLit ^^ {case f => Literal(f.toDouble) } + | stringLit ^^ {case s => Literal(s, StringType) } + ) protected lazy val floatLit: Parser[String] = elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - (expression <~ ".") ~ ident ^^ { - case base ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - dotExpressionHeader | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal + ( expression ~ ("[" ~> expression <~ "]") ^^ + { case base ~ ordinal => GetItem(base, ordinal) } + | (expression <~ ".") ~ ident ^^ + { case base ~ fieldName => GetField(base, fieldName) } + | TRUE ^^^ Literal(true, BooleanType) + | FALSE ^^^ Literal(false, BooleanType) + | cast + | "(" ~> expression <~ ")" + | function + | "-" ~> literal ^^ UnaryMinus + | dotExpressionHeader + | ident ^^ UnresolvedAttribute + | "*" ^^^ Star(None) + | literal + ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { @@ -400,55 +338,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } - -class SqlLexical(val keywords: Seq[String]) extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - - reserved ++= keywords.flatMap(w => allCaseVersions(w)) - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." - ) - - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 9a3848cfc6b62..b8ba2ee428a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command { } /** - * Commands of the form "SET (key) (= value)". + * Commands of the form "SET [key [= value] ]". */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { +case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) } @@ -81,3 +81,14 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "! shellCommand" command + */ +case class ShellCommand(cmd: String) extends Command + + +/** + * Returned for the "SOURCE file" command + */ +case class SourceCommand(filePath: String) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 014e1e2826724..23e7b2d270777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -66,12 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = true) + @transient protected[sql] val optimizer = Optimizer + @transient - protected[sql] val parser = new catalyst.SqlParser + protected[sql] val sqlParser = { + val fallback = new catalyst.SqlParser + new catalyst.SparkSQLParser(fallback(_)) + } - protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bbf17b9fadf86..4f1af7234d551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -304,8 +304,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.SetCommand(key, value) => - Seq(execution.SetCommand(key, value, plan.output)(context)) + case logical.SetCommand(kv) => + Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheTableCommand(tableName, optPlan, isLazy) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index d49633c24ad4d..5859eba408ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -48,29 +48,28 @@ trait Command { * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - key: Option[String], value: Option[String], output: Seq[Attribute])( +case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { - // Set value for key k. - case (Some(k), Some(v)) => - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + override protected lazy val sideEffectResult: Seq[Row] = kv match { + // Set value for the key. + case Some((key, Some(value))) => + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) } else { - context.setConf(k, v) - Seq(Row(s"$k=$v")) + context.setConf(key, value) + Seq(Row(s"$key=$value")) } - // Query the value bound to key k. - case (Some(k), _) => + // Query the value bound to the key. + case Some((key, None)) => // TODO (lian) This is just a workaround to make the Simba ODBC driver work. // Should remove this once we get the ODBC driver updated. - if (k == "-v") { + if (key == "-v") { val hiveJars = Seq( "hive-exec-0.12.0.jar", "hive-service-0.12.0.jar", @@ -84,23 +83,20 @@ case class SetCommand( Row("system:java.class.path=" + hiveJars), Row("system:sun.java.command=shark.SharkServer2")) } else { - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) } else { - Seq(Row(s"$k=${context.getConf(k, "")}")) + Seq(Row(s"$key=${context.getConf(key, "")}")) } } // Query all key-value pairs that are set in the SQLConf of the context. - case (None, None) => + case _ => context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq - - case _ => - throw new IllegalArgumentException() } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e624f97004f5..c87ded81fdc27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -69,7 +69,7 @@ class CachedTableSuite extends QueryTest { test("calling .unpersist() should drop in-memory columnar cache") { table("testData").cache() table("testData").count() - table("testData").unpersist(true) + table("testData").unpersist(blocking = true) assertCached(table("testData"), 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 79de1bb855dbe..a94022c0cf6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -42,7 +42,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TimeZone.setDefault(origZone) } - test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), @@ -61,7 +60,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 4) } - test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), @@ -694,6 +692,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { checkAnswer( - sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 910174a153768..accf61576b804 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => sessionToActivePool(parentSession) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index c5844e92eaaa9..430ffb29989ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -18,118 +18,50 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers + import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific - * extensions like CACHE TABLE and UNCACHE TABLE. + * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (input.trim.startsWith("!")) { - ShellCommand(input.drop(1)) - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - - protected val ADD = Keyword("ADD") - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SOURCE = Keyword("SOURCE") - protected val TABLE = Keyword("TABLE") - protected val UNCACHE = Keyword("UNCACHE") - +private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - protected def allCaseConverse(k: String): Parser[String] = - lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _) + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") - protected val reservedWords = - this.getClass + private val reservedWords = + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = - cache | uncache | addJar | addFile | dfs | source | hiveQl + protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim()) - } - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) + case statement => HiveQl.createPlan(statement.trim) } - protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case jar => AddJar(jar.trim()) + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim) } - protected lazy val addFile: Parser[LogicalPlan] = + private lazy val addFile: Parser[LogicalPlan] = ADD ~ FILE ~> restInput ^^ { - case file => AddFile(file.trim()) + case input => AddFile(input.trim) } - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => NativeCommand(command.trim()) - } - - protected lazy val source: Parser[LogicalPlan] = - SOURCE ~> restInput ^^ { - case file => SourceCommand(file.trim()) + private lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> restInput ^^ { + case input => AddJar(input.trim) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 32c9175f181bb..98a46a31e1ffd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.spark.sql.catalyst.SparkSQLParser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -38,10 +39,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class ShellCommand(cmd: String) extends Command - -private[hive] case class SourceCommand(filePath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command private[hive] case class AddJar(path: String) extends Command @@ -126,9 +123,11 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands - - // It parses hive sql query along with with several Spark SQL specific extensions - protected val hiveSqlParser = new ExtendedHiveQlParser + + protected val hqlParser = { + val fallback = new ExtendedHiveQlParser + new SparkSQLParser(fallback(_)) + } /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -218,7 +217,7 @@ private[hive] object HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser(sql) /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 508d8239c7628..5c66322f1ed99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -167,10 +167,10 @@ private[hive] trait HiveStrategies { database.get, tableName, query, - InsertIntoHiveTable(_: MetastoreRelation, - Map(), - query, - true)(hiveContext)) :: Nil + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + overwrite = true)(hiveContext)) :: Nil case _ => Nil } } From 421382d0e728940caa3e61bc11237c61f256378a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 9 Oct 2014 18:26:43 -0700 Subject: [PATCH 255/315] [SPARK-3824][SQL] Sets in-memory table default storage level to MEMORY_AND_DISK Using `MEMORY_AND_DISK` as default storage level for in-memory table caching. Due to the in-memory columnar representation, recomputing an in-memory cached table partitions can be very expensive. Author: Cheng Lian Closes #2686 from liancheng/spark-3824 and squashes the following commits: 35d2ed0 [Cheng Lian] Removes extra space 1ab7967 [Cheng Lian] Reduces test data size to fit DiskStore.getBytes() ba565f0 [Cheng Lian] Maks CachedBatch serializable 07f0204 [Cheng Lian] Sets in-memory table default storage level to MEMORY_AND_DISK --- .../main/scala/org/apache/spark/sql/CacheManager.scala | 10 +++++++--- .../spark/sql/columnar/InMemoryColumnarTableScan.scala | 9 +++++---- .../scala/org/apache/spark/sql/CachedTableSuite.scala | 10 +++++----- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3bf7382ac67a6..5ab2b5316ab10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel -import org.apache.spark.storage.StorageLevel.MEMORY_ONLY +import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -74,10 +74,14 @@ private[sql] trait CacheManager { cachedData.clear() } - /** Caches the data produced by the logical representation of the given schema rdd. */ + /** + * Caches the data produced by the logical representation of the given schema rdd. Unlike + * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing + * the in-memory columnar representation of the underlying table is expensive. + */ private[sql] def cacheQuery( query: SchemaRDD, - storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 4f79173a26f88..22ab0e2613f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -38,7 +38,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } -private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) - CachedBatch(columnBuilders.map(_.build()), stats) + CachedBatch(columnBuilders.map(_.build().array()), stats) } def hasNext = rowIterator.hasNext @@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan( def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = - requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + val columnAccessors = requestedColumnIndices.map { batch => + ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + } // Extract rows via column accessors new Iterator[Row] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c87ded81fdc27..444bc95009c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") - cacheTable("bigData") - assert(table("bigData").count() === 1000000L) - uncacheTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist() } test("calling .cache() should use in-memory columnar caching") { From 6f98902a3d7749e543bc493a8c62b1e3a7b924cc Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 9 Oct 2014 18:41:36 -0700 Subject: [PATCH 256/315] [SPARK-3834][SQL] Backticks not correctly handled in subquery aliases The queries like SELECT a.key FROM (SELECT key FROM src) \`a\` does not work as backticks in subquery aliases are not handled properly. This PR fixes that. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2737 from ravipesala/SPARK-3834 and squashes the following commits: 0e0ab98 [ravipesala] Fixing issue in backtick handling for subquery aliases --- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 98a46a31e1ffd..7cc14dc7a9c9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -638,7 +638,7 @@ private[hive] object HiveQl { def nodeToRelation(node: Node): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(alias, nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3647bb1c4ce7d..fbe6ac765c009 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), sql("SELECT `key` FROM src").collect().toSeq) - } + } + + test("SPARK-3834 Backticks not correctly handled in subquery aliases") { + checkAnswer( + sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + sql("SELECT `key` FROM src").collect().toSeq) + } } From 411cf29fff011561f0093bb6101af87842828369 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Fri, 10 Oct 2014 00:46:56 -0700 Subject: [PATCH 257/315] [SPARK-2805] Upgrade Akka to 2.3.4 This is a second rev of the Akka upgrade (earlier merged, but reverted). I made a slight modification which is that I also upgrade Hive to deal with a compatibility issue related to the protocol buffers library. Author: Anand Avati Author: Patrick Wendell Closes #2752 from pwendell/akka-upgrade and squashes the following commits: 4c7ca3f [Patrick Wendell] Upgrading to new hive->protobuf version 57a2315 [Anand Avati] SPARK-1812: streaming - remove tests which depend on akka.actor.IO 2a551d3 [Anand Avati] SPARK-1812: core - upgrade to akka 2.3.4 --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 4 +- .../spark/streaming/InputStreamsSuite.scala | 71 ------------------- 6 files changed, 7 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 7756c89b00cad..d047b9e307d4b 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 @@ -127,7 +127,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0 + 0.12.0-protobuf 1.4.3 1.2.3 8.1.14.v20131031 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index a44a45a3e9bd6..fa04fa326e370 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -143,59 +141,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -377,22 +322,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 90f73fcc47c7bf881f808653d46a9936f37c3c31 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 10 Oct 2014 01:44:36 -0700 Subject: [PATCH 258/315] [SPARK-3889] Attempt to avoid SIGBUS by not mmapping files in ConnectionManager In general, individual shuffle blocks are frequently small, so mmapping them often creates a lot of waste. It may not be bad to mmap the larger ones, but it is pretty inconvenient to get configuration into ManagedBuffer, and besides it is unlikely to help all that much. Author: Aaron Davidson Closes #2742 from aarondav/mmap and squashes the following commits: a152065 [Aaron Davidson] Add other pathway back 52b6cd2 [Aaron Davidson] [SPARK-3889] Attempt to avoid SIGBUS by not mmapping files in ConnectionManager --- .../org/apache/spark/network/ManagedBuffer.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec907..4c9ca97a2a6b7 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer { final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + override def size: Long = length override def nioByteBuffer(): ByteBuffer = { var channel: FileChannel = null try { channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. + if (length < MIN_MEMORY_MAP_BYTES) { + val buf = ByteBuffer.allocate(length.toInt) + channel.read(buf, offset) + buf.flip() + buf + } else { + channel.map(MapMode.READ_ONLY, offset, length) + } } catch { case e: IOException => Try(channel.size).toOption match { From 72f36ee571ad27c7c7c70bb9aecc7e6ef51dfd44 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 10 Oct 2014 14:14:05 -0700 Subject: [PATCH 259/315] [SPARK-3886] [PySpark] use AutoBatchedSerializer by default Use AutoBatchedSerializer by default, which will choose the proper batch size based on size of serialized objects, let the size of serialized batch fall in into [64k - 640k]. In JVM, the serializer will also track the objects in batch to figure out duplicated objects, larger batch may cause OOM in JVM. Author: Davies Liu Closes #2740 from davies/batchsize and squashes the following commits: 52cdb88 [Davies Liu] update docs 185f2b9 [Davies Liu] use AutoBatchedSerializer by default --- python/pyspark/context.py | 11 +++++++---- python/pyspark/serializers.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6fb30d65c5edd..85c04624da4a6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -67,7 +67,7 @@ class SparkContext(object): _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, gateway=None): """ Create a new SparkContext. At least the master and app name should be set, @@ -83,8 +83,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param environment: A dictionary of environment variables to set on worker nodes. :param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. + Java object. Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size :param serializer: The serializer for RDDs. :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM @@ -117,6 +118,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._unbatched_serializer = serializer if batchSize == 1: self.serializer = self._unbatched_serializer + elif batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 099fa54cf2bd7..3d1a34b281acc 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -220,7 +220,7 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): + def __init__(self, serializer, bestSize=1 << 16): BatchedSerializer.__init__(self, serializer, -1) self.bestSize = bestSize @@ -247,7 +247,7 @@ def __eq__(self, other): other.serializer == self.serializer) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer<%s>" % str(self.serializer) class CartesianDeserializer(FramedSerializer): From 1d72a30874a88bdbab75217f001cf2af409016e7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 10 Oct 2014 16:49:19 -0700 Subject: [PATCH 260/315] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. We had to upgrade our Hive 0.12 version as well to deal with a protobuf conflict (both hive and akka have been using a shaded protobuf version). This is testing a correctly patched version of Hive 0.12. Author: Patrick Wendell Closes #2756 from pwendell/hotfix and squashes the following commits: cc979d0 [Patrick Wendell] HOTFIX: Fix build issue with Akka 2.3.4 upgrade. --- pom.xml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index d047b9e307d4b..288bbf1114bea 100644 --- a/pom.xml +++ b/pom.xml @@ -127,7 +127,7 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0-protobuf + 0.12.0-protobuf-2.5 1.4.3 1.2.3 8.1.14.v20131031 @@ -223,6 +223,18 @@ false + + + spark-staging + Spring Staging Repository + https://oss.sonatype.org/content/repositories/orgspark-project-1085 + + true + + + false + + From 0e8203f4fb721158fb27897680da476174d24c4b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 10 Oct 2014 18:39:55 -0700 Subject: [PATCH 261/315] [SPARK-2924] Required by scala 2.11, only one fun/ctor amongst overriden alternatives, can have default argument(s). ...riden alternatives, can have default argument. Author: Prashant Sharma Closes #2750 from ScrapCodes/SPARK-2924/default-args-removed and squashes the following commits: d9785c3 [Prashant Sharma] [SPARK-2924] Required by scala 2.11, only one function/ctor amongst overriden alternatives, can have default argument. --- .../org/apache/spark/util/FileLogger.scala | 19 +++++++++++++++++-- .../apache/spark/util/FileLoggerSuite.scala | 8 ++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6d1fc05a15d2c..fdc73f08261a6 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -51,12 +51,27 @@ private[spark] class FileLogger( def this( logDir: String, sparkConf: SparkConf, - compress: Boolean = false, - overwrite: Boolean = true) = { + compress: Boolean, + overwrite: Boolean) = { this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, overwrite = overwrite) } + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = true) + } + + def this( + logDir: String, + sparkConf: SparkConf) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false, + overwrite = true) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index dc2a05631d83d..72466a3aa1130 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -74,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { test("Logging when directory already exists") { // Create the logging directory multiple times - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() // If overwrite is not enabled, an exception should be thrown intercept[IOException] { - new FileLogger(logDirPathString, new SparkConf, overwrite = false).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start() } } From 81015a2ba49583d730ce65b2262f50f1f2451a79 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:26:17 -0700 Subject: [PATCH 262/315] [SPARK-3867][PySpark] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed ./python/run-tests search a Python 2.6 executable on PATH and use it if available. When using Python 2.6, it is going to import unittest2 module which is not a standard library in Python 2.6, so it fails with ImportError. Author: cocoatomo Closes #2759 from cocoatomo/issues/3867-unittest2-import-error and squashes the following commits: f068eb5 [cocoatomo] [SPARK-3867] ./python/run-tests failed when it run with Python 2.6 and unittest2 is not installed --- python/pyspark/mllib/tests.py | 6 +++++- python/pyspark/tests.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5c20e100e144f..463faf7b6f520 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,7 +25,11 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..ceab57464f013 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,7 +34,11 @@ from platform import python_implementation if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest From 7a3f589ef86200f99624fea8322e5af0cad774a7 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Sat, 11 Oct 2014 11:51:59 -0700 Subject: [PATCH 263/315] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings Sphinx documents contains a corrupted ReST format and have some warnings. The purpose of this issue is same as https://issues.apache.org/jira/browse/SPARK-3773. commit: 0e8203f4fb721158fb27897680da476174d24c4b output ``` $ cd ./python/docs $ make clean html rm -rf _build/* sphinx-build -b html -d _build/doctrees . _build/html Making output directory... Running Sphinx v1.2.3 loading pickled environment... not yet created building [html]: targets for 4 source files that are out of date updating environment: 4 added, 0 changed, 0 removed reading sources... [100%] pyspark.sql /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.findSynonyms:4: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/feature.py:docstring of pyspark.mllib.feature.Word2VecModel.transform:3: WARNING: Field list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/sql.py:docstring of pyspark.sql:4: WARNING: Bullet list ends without a blank line; unexpected unindent. looking for now-outdated files... none found pickling environment... done checking consistency... done preparing documents... done writing output... [100%] pyspark.sql writing additional files... (12 module code pages) _modules/index search copying static files... WARNING: html_static_path entry u'/Users//MyRepos/Scala/spark/python/docs/_static' does not exist done copying extra files... done dumping search index... done dumping object inventory... done build succeeded, 4 warnings. Build finished. The HTML pages are in _build/html. ``` Author: cocoatomo Closes #2766 from cocoatomo/issues/3909-sphinx-build-warnings and squashes the following commits: 2c7faa8 [cocoatomo] [SPARK-3909][PySpark][Doc] A corrupted format in Sphinx documents and building warnings --- python/docs/conf.py | 2 +- python/pyspark/mllib/feature.py | 2 ++ python/pyspark/rdd.py | 2 +- python/pyspark/sql.py | 10 +++++----- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/docs/conf.py b/python/docs/conf.py index 8e6324f058251..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index a44a27fd3b6a6..f4cbf31b94fe2 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -44,6 +44,7 @@ def transform(self, word): """ :param word: a word :return: vector representation of word + Transforms a word to its vector representation Note: local use only @@ -57,6 +58,7 @@ def findSynonyms(self, x, num): :param x: a word or a vector representation of word :param num: number of synonyms to find :return: array of (word, cosineSimilarity) + Find synonyms of a word Note: local use only diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6797d50659a92..e13bab946c44a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - :param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d3d36eb995ab6..b31a82f9b19ac 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -19,14 +19,14 @@ public classes of Spark SQL: - L{SQLContext} - Main entry point for SQL functionality. + Main entry point for SQL functionality. - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. - L{Row} - A Row of data returned by a Spark SQL query. + A Row of data returned by a Spark SQL query. - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. + Main entry point for accessing data stored in Apache Hive.. """ import itertools From 69c67abaa9d4bb4b95792d1862bc65efc764c194 Mon Sep 17 00:00:00 2001 From: giwa Date: Sun, 12 Oct 2014 02:46:56 -0700 Subject: [PATCH 264/315] [SPARK-2377] Python API for Streaming This patch brings Python API for Streaming. This patch is based on work from @giwa Author: giwa Author: Ken Takagiwa Author: Davies Liu Author: Ken Takagiwa Author: Tathagata Das Author: Ken Author: Ken Takagiwa Author: Matthew Farrellee Closes #2538 from davies/streaming and squashes the following commits: 64561e4 [Davies Liu] fix tests 331ecce [Davies Liu] fix example 3e2492b [Davies Liu] change updateStateByKey() to easy API 182be73 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 02d0575 [Davies Liu] add wrapper for foreachRDD() bebeb4a [Davies Liu] address all comments 6db00da [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 8380064 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 52c535b [Davies Liu] remove fix for sum() e108ec1 [Davies Liu] address comments 37fe06f [Davies Liu] use random port for callback server d05871e [Davies Liu] remove reuse of PythonRDD be5e5ff [Davies Liu] merge branch of env, make tests stable. 8071541 [Davies Liu] Merge branch 'env' into streaming c7bbbce [Davies Liu] fix sphinx docs 6bb9d91 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 4d0ea8b [Davies Liu] clear reference of SparkEnv after stop 54bd92b [Davies Liu] improve tests c2b31cb [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming 7a88f9f [Davies Liu] rollback RDD.setContext(), use textFileStream() to test checkpointing bd8a4c2 [Davies Liu] fix scala style 7797c70 [Davies Liu] refactor ff88bec [Davies Liu] rename RDDFunction to TransformFunction d328aca [Davies Liu] fix serializer in queueStream 6f0da2f [Davies Liu] recover from checkpoint fa7261b [Davies Liu] refactor a13ff34 [Davies Liu] address comments 8466916 [Davies Liu] support checkpoint 9a16bd1 [Davies Liu] change number of partitions during tests b98d63f [Davies Liu] change private[spark] to private[python] eed6e2a [Davies Liu] rollback not needed changes e00136b [Davies Liu] address comments 069a94c [Davies Liu] fix the number of partitions during window() 338580a [Davies Liu] change _first(), _take(), _collect() as private API 19797f9 [Davies Liu] clean up 6ebceca [Davies Liu] add more tests c40c52d [Davies Liu] change first(), take(n) to has the same behavior as RDD 98ac6c2 [Davies Liu] support ssc.transform() b983f0f [Davies Liu] address comments 847f9b9 [Davies Liu] add more docs, add first(), take() e059ca2 [Davies Liu] move check of window into Python fce0ef5 [Davies Liu] rafactor of foreachRDD() 7001b51 [Davies Liu] refactor of queueStream() 26ea396 [Davies Liu] refactor 74df565 [Davies Liu] fix print and docs b32774c [Davies Liu] move java_import into streaming 604323f [Davies Liu] enable streaming tests c499ba0 [Davies Liu] remove Time and Duration 3f0fb4b [Davies Liu] refactor fix tests c28f520 [Davies Liu] support updateStateByKey d357b70 [Davies Liu] support windowed dstream bd13026 [Davies Liu] fix examples eec401e [Davies Liu] refactor, combine TransformedRDD, fix reuse PythonRDD, fix union 9a57685 [Davies Liu] fix python style bd27874 [Davies Liu] fix scala style 7339be0 [Davies Liu] delete tests 7f53086 [Davies Liu] support transform(), refactor and cleanup df098fc [Davies Liu] Merge branch 'master' into giwa 550dfd9 [giwa] WIP fixing 1.1 merge 5cdb6fa [giwa] changed for SCCallSiteSync e685853 [giwa] meged with rebased 1.1 branch 2d32a74 [giwa] added some StreamingContextTestSuite 4a59e1e [giwa] WIP:added more test for StreamingContext 8ffdbf1 [giwa] added atexit to handle callback server d5f5fcb [giwa] added comment for StreamingContext.sparkContext 63c881a [giwa] added StreamingContext.sparkContext d39f102 [giwa] added StreamingContext.remember d542743 [giwa] clean up code 2fdf0de [Matthew Farrellee] Fix scalastyle errors c0a06bc [giwa] delete not implemented functions f385976 [giwa] delete inproper comments b0f2015 [giwa] added comment in dstream._test_output bebb3f3 [giwa] remove the last brank line fbed8da [giwa] revert pom.xml 8ed93af [giwa] fixed explanaiton 066ba90 [giwa] revert pom.xml fa4af88 [giwa] remove duplicated import 6ae3caa [giwa] revert pom.xml 7dc7391 [giwa] fixed typo 62dc7a3 [giwa] clean up exmples f04882c [giwa] clen up examples b171ec3 [giwa] fixed pep8 violation f198d14 [giwa] clean up code 3166d31 [giwa] clean up c00e091 [giwa] change test case not to use awaitTermination e80647e [giwa] adopted the latest compression way of python command 58e41ff [giwa] merge with master 455e5af [giwa] removed wasted print in DStream af336b7 [giwa] add comments ddd4ee1 [giwa] added TODO coments 99ce042 [giwa] added saveAsTextFiles and saveAsPickledFiles 2a06cdb [giwa] remove waste duplicated code c5ecfc1 [giwa] basic function test cases are passed 8dcda84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 795b2cd [giwa] broke something 1e126bf [giwa] WIP: solved partitioned and None is not recognized f67cf57 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test 953deb0 [giwa] edited the comment to add more precise description af610d3 [giwa] removed unnesessary changes c1d546e [giwa] fixed PEP-008 violation 99410be [giwa] delete waste file b3b0362 [giwa] added basic operation test cases 9cde7c9 [giwa] WIP added test case bd3ba53 [giwa] WIP 5c04a5f [giwa] WIP: added PythonTestInputStream 019ef38 [giwa] WIP 1934726 [giwa] update comment 376e3ac [giwa] WIP 932372a [giwa] clean up dstream.py 0b09cff [giwa] added stop in StreamingContext 92e333e [giwa] implemented reduce and count function in Dstream 1b83354 [giwa] Removed the waste line 88f7506 [Ken Takagiwa] Kill py4j callback server properly 54b5358 [Ken Takagiwa] tried to restart callback server 4f07163 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. fe02547 [Ken Takagiwa] remove waste file 2ad7bd3 [Ken Takagiwa] clean up codes 6197a11 [Ken Takagiwa] clean up code eb4bf48 [Ken Takagiwa] fix map function 98c2a00 [Ken Takagiwa] added count operation but this implementation need double check 58591d2 [Ken Takagiwa] reduceByKey is working 0df7111 [Ken Takagiwa] delete old file f485b1d [Ken Takagiwa] fied input of socketTextDStream dd6de81 [Ken Takagiwa] initial commit for socketTextStream 247fd74 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 4bcb318 [Ken Takagiwa] implementing transform function in Python 38adf95 [Ken Takagiwa] added reducedByKey not working yet 66fcfff [Ken Takagiwa] modify dstream.py to fix indent error 41886c2 [Ken Takagiwa] comment PythonDStream.PairwiseDStream 0b99bec [Ken] initial commit for pySparkStreaming c214199 [giwa] added testcase for combineByKey 5625bdc [giwa] added gorupByKey testcase 10ab87b [giwa] added sparkContext as input parameter in StreamingContext 10b5b04 [giwa] removed wasted print in DStream e54f986 [giwa] add comments 16aa64f [giwa] added TODO coments 74535d4 [giwa] added saveAsTextFiles and saveAsPickledFiles f76c182 [giwa] remove waste duplicated code 18c8723 [giwa] modified streaming test case to add coment 13fb44c [giwa] basic function test cases are passed 3000b2b [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 ff14070 [giwa] broke something bcdec33 [giwa] WIP: solved partitioned and None is not recognized 270a9e1 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test bb10956 [giwa] edited the comment to add more precise description 253a863 [giwa] removed unnesessary changes 3d37822 [giwa] fixed PEP-008 violation f21cab3 [giwa] delete waste file 878bad7 [giwa] added basic operation test cases ce2acd2 [giwa] WIP added test case 9ad6855 [giwa] WIP 1df77f5 [giwa] WIP: added PythonTestInputStream 1523b66 [giwa] WIP 8a0fbbc [giwa] update comment fe648e3 [giwa] WIP 29c2bc5 [giwa] initial commit for testcase 4d40d63 [giwa] clean up dstream.py c462bb3 [giwa] added stop in StreamingContext d2c01ba [giwa] clean up examples 3c45cd2 [giwa] implemented reduce and count function in Dstream b349649 [giwa] Removed the waste line 3b498e1 [Ken Takagiwa] Kill py4j callback server properly 84a9668 [Ken Takagiwa] tried to restart callback server 9ab8952 [Tathagata Das] Added extra line. 05e991b [Tathagata Das] Added missing file b1d2a30 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. 678e854 [Ken Takagiwa] remove waste file 0a8bbbb [Ken Takagiwa] clean up codes bab31c1 [Ken Takagiwa] clean up code 72b9738 [Ken Takagiwa] fix map function d3ee86a [Ken Takagiwa] added count operation but this implementation need double check 15feea9 [Ken Takagiwa] edit python sparkstreaming example 6f98e50 [Ken Takagiwa] reduceByKey is working c455c8d [Ken Takagiwa] added reducedByKey not working yet dc6995d [Ken Takagiwa] delete old file b31446a [Ken Takagiwa] fixed typo of network_workdcount.py ccfd214 [Ken Takagiwa] added doctest for pyspark.streaming.duration 0d1b954 [Ken Takagiwa] fied input of socketTextDStream f746109 [Ken Takagiwa] initial commit for socketTextStream bb7ccf3 [Ken Takagiwa] remove unused import in python 224fc5e [Ken Takagiwa] add empty line d2099d8 [Ken Takagiwa] sorted the import following Spark coding convention 5bac7ec [Ken Takagiwa] revert streaming/pom.xml e1df940 [Ken Takagiwa] revert pom.xml 494cae5 [Ken Takagiwa] remove not implemented DStream functions in python 17a74c6 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 1a0f065 [Ken Takagiwa] implementing transform function in Python d7b4d6f [Ken Takagiwa] added reducedByKey not working yet 87438e2 [Ken Takagiwa] modify dstream.py to fix indent error b406252 [Ken Takagiwa] comment PythonDStream.PairwiseDStream 454981d [Ken] initial commit for pySparkStreaming 150b94c [giwa] added some StreamingContextTestSuite f7bc8f9 [giwa] WIP:added more test for StreamingContext ee50c5a [giwa] added atexit to handle callback server fdc9125 [giwa] added comment for StreamingContext.sparkContext f5bfb70 [giwa] added StreamingContext.sparkContext da09768 [giwa] added StreamingContext.remember d68b568 [giwa] clean up code 4afa390 [giwa] clean up code 1fd6bc7 [Ken Takagiwa] Merge pull request #2 from mattf/giwa-master d9d59fe [Matthew Farrellee] Fix scalastyle errors 67473a9 [giwa] delete not implemented functions c97377c [giwa] delete inproper comments 2ea769e [giwa] added comment in dstream._test_output 3b27bd4 [giwa] remove the last brank line acfcaeb [giwa] revert pom.xml 93f7637 [giwa] fixed explanaiton 50fd6f9 [giwa] revert pom.xml 4f82c89 [giwa] remove duplicated import 9d1de23 [giwa] revert pom.xml 7339df2 [giwa] fixed typo 9c85e48 [giwa] clean up exmples 24f95db [giwa] clen up examples 0d30109 [giwa] fixed pep8 violation b7dab85 [giwa] improve test case 583e66d [giwa] move tests for streaming inside streaming directory 1d84142 [giwa] remove unimplement test f0ea311 [giwa] clean up code 171edeb [giwa] clean up 4dedd2d [giwa] change test case not to use awaitTermination 268a6a5 [giwa] Changed awaitTermination not to call awaitTermincation in Scala. Just use time.sleep instread 09a28bf [giwa] improve testcases 58150f5 [giwa] Changed the test case to focus the test operation 199e37f [giwa] adopted the latest compression way of python command 185fdbf [giwa] merge with master f1798c4 [giwa] merge with master e70f706 [giwa] added testcase for combineByKey e162822 [giwa] added gorupByKey testcase 97742fe [giwa] added sparkContext as input parameter in StreamingContext 14d4c0e [giwa] removed wasted print in DStream 6d8190a [giwa] add comments 4aa99e4 [giwa] added TODO coments e9fab72 [giwa] added saveAsTextFiles and saveAsPickledFiles 94f2b65 [giwa] remove waste duplicated code 580fbc2 [giwa] modified streaming test case to add coment 99e4bb3 [giwa] basic function test cases are passed 7051a84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 35933e1 [giwa] broke something 9767712 [giwa] WIP: solved partitioned and None is not recognized 4f2d7e6 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test 33c0f94d [giwa] edited the comment to add more precise description 774f18d [giwa] removed unnesessary changes 3a671cc [giwa] remove export PYSPARK_PYTHON in spark submit 8efa266 [giwa] fixed PEP-008 violation fa75d71 [giwa] delete waste file 7f96294 [giwa] added basic operation test cases 3dda31a [giwa] WIP added test case 1f68b78 [giwa] WIP c05922c [giwa] WIP: added PythonTestInputStream 1fd12ae [giwa] WIP c880a33 [giwa] update comment 5d22c92 [giwa] WIP ea4b06b [giwa] initial commit for testcase 5a9b525 [giwa] clean up dstream.py 79c5809 [giwa] added stop in StreamingContext 189dcea [giwa] clean up examples b8d7d24 [giwa] implemented reduce and count function in Dstream b6468e6 [giwa] Removed the waste line b47b5fd [Ken Takagiwa] Kill py4j callback server properly 19ddcdd [Ken Takagiwa] tried to restart callback server c9fc124 [Tathagata Das] Added extra line. 4caae3f [Tathagata Das] Added missing file 4eff053 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. 5e822d4 [Ken Takagiwa] remove waste file aeaf8a5 [Ken Takagiwa] clean up codes 9fa249b [Ken Takagiwa] clean up code 05459c6 [Ken Takagiwa] fix map function a9f4ecb [Ken Takagiwa] added count operation but this implementation need double check d1ee6ca [Ken Takagiwa] edit python sparkstreaming example 0b8b7d0 [Ken Takagiwa] reduceByKey is working d25d5cf [Ken Takagiwa] added reducedByKey not working yet 7f7c5d1 [Ken Takagiwa] delete old file 967dc26 [Ken Takagiwa] fixed typo of network_workdcount.py 57fb740 [Ken Takagiwa] added doctest for pyspark.streaming.duration 4b69fb1 [Ken Takagiwa] fied input of socketTextDStream 02f618a [Ken Takagiwa] initial commit for socketTextStream 4ce4058 [Ken Takagiwa] remove unused import in python 856d98e [Ken Takagiwa] add empty line 490e338 [Ken Takagiwa] sorted the import following Spark coding convention 5594bd4 [Ken Takagiwa] revert pom.xml 2adca84 [Ken Takagiwa] remove not implemented DStream functions in python e551e13 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit 3758175 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit c5518b4 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 dcf243f [Ken Takagiwa] implementing transform function in Python 9af03f4 [Ken Takagiwa] added reducedByKey not working yet 6e0d9c7 [Ken Takagiwa] modify dstream.py to fix indent error e497b9b [Ken Takagiwa] comment PythonDStream.PairwiseDStream 5c3a683 [Ken] initial commit for pySparkStreaming 665bfdb [giwa] added testcase for combineByKey a3d2379 [giwa] added gorupByKey testcase 636090a [giwa] added sparkContext as input parameter in StreamingContext e7ebb08 [giwa] removed wasted print in DStream d8b593b [giwa] add comments ea9c873 [giwa] added TODO coments 89ae38a [giwa] added saveAsTextFiles and saveAsPickledFiles e3033fc [giwa] remove waste duplicated code a14c7e1 [giwa] modified streaming test case to add coment 536def4 [giwa] basic function test cases are passed 2112638 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4 080541a [giwa] broke something 0704b86 [giwa] WIP: solved partitioned and None is not recognized 90a6484 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test a65f302 [giwa] edited the comment to add more precise description bdde697 [giwa] removed unnesessary changes e8c7bfc [giwa] remove export PYSPARK_PYTHON in spark submit 3334169 [giwa] fixed PEP-008 violation db0a303 [giwa] delete waste file 2cfd3a0 [giwa] added basic operation test cases 90ae568 [giwa] WIP added test case a120d07 [giwa] WIP f671cdb [giwa] WIP: added PythonTestInputStream 56fae45 [giwa] WIP e35e101 [giwa] Merge branch 'master' into testcase ba5112d [giwa] update comment 28aa56d [giwa] WIP fb08559 [giwa] initial commit for testcase a613b85 [giwa] clean up dstream.py c40c0ef [giwa] added stop in StreamingContext 31e4260 [giwa] clean up examples d2127d6 [giwa] implemented reduce and count function in Dstream 48f7746 [giwa] Removed the waste line 0f83eaa [Ken Takagiwa] delete py4j 0.8.1 1679808 [Ken Takagiwa] Kill py4j callback server properly f96cd4e [Ken Takagiwa] tried to restart callback server fe86198 [Ken Takagiwa] add py4j 0.8.2.1 but server is not launched 1064fe0 [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark 28c6620 [Ken Takagiwa] Implemented DStream.foreachRDD in the Python API using Py4J callback server 85b0fe1 [Ken Takagiwa] Merge pull request #1 from tdas/python-foreach 54e2e8c [Tathagata Das] Added extra line. e185338 [Tathagata Das] Added missing file a778d4b [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server. cc2092b [Ken Takagiwa] remove waste file d042ac6 [Ken Takagiwa] clean up codes 84a021f [Ken Takagiwa] clean up code bd20e17 [Ken Takagiwa] fix map function d01a125 [Ken Takagiwa] added count operation but this implementation need double check 7d05109 [Ken Takagiwa] merge with remote branch ae464e0 [Ken Takagiwa] edit python sparkstreaming example 04af046 [Ken Takagiwa] reduceByKey is working 3b6d7b0 [Ken Takagiwa] implementing transform function in Python 571d52d [Ken Takagiwa] added reducedByKey not working yet 5720979 [Ken Takagiwa] delete old file e604fcb [Ken Takagiwa] fixed typo of network_workdcount.py 4b7c08b [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark ce7d426 [Ken Takagiwa] added doctest for pyspark.streaming.duration a8c9fd5 [Ken Takagiwa] fixed for socketTextStream a61fa9e [Ken Takagiwa] fied input of socketTextDStream 1e84f41 [Ken Takagiwa] initial commit for socketTextStream 6d012f7 [Ken Takagiwa] remove unused import in python 25d30d5 [Ken Takagiwa] add empty line 6e0a64a [Ken Takagiwa] sorted the import following Spark coding convention fa4a7fc [Ken Takagiwa] revert streaming/pom.xml 8f8202b [Ken Takagiwa] revert streaming pom.xml c9d79dd [Ken Takagiwa] revert pom.xml 57e3e52 [Ken Takagiwa] remove not implemented DStream functions in python 0a516f5 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit a7a0b5c [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit 72bfc66 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10 69e9cd3 [Ken Takagiwa] implementing transform function in Python 94a0787 [Ken Takagiwa] added reducedByKey not working yet 88068cf [Ken Takagiwa] modify dstream.py to fix indent error 1367be5 [Ken Takagiwa] comment PythonDStream.PairwiseDStream eb2b3ba [Ken] Merge remote-tracking branch 'upstream/master' d8e51f9 [Ken] initial commit for pySparkStreaming --- .../apache/spark/api/python/PythonRDD.scala | 10 +- .../main/python/streaming/hdfs_wordcount.py | 49 ++ .../python/streaming/network_wordcount.py | 48 ++ .../streaming/stateful_network_wordcount.py | 57 ++ python/docs/epytext.py | 2 +- python/docs/index.rst | 1 + python/docs/pyspark.rst | 3 +- python/pyspark/context.py | 8 +- python/pyspark/serializers.py | 3 + python/pyspark/streaming/__init__.py | 21 + python/pyspark/streaming/context.py | 325 +++++++++ python/pyspark/streaming/dstream.py | 621 ++++++++++++++++++ python/pyspark/streaming/tests.py | 545 +++++++++++++++ python/pyspark/streaming/util.py | 128 ++++ python/run-tests | 7 + .../streaming/api/java/JavaDStreamLike.scala | 2 +- .../streaming/api/python/PythonDStream.scala | 316 +++++++++ 17 files changed, 2133 insertions(+), 13 deletions(-) create mode 100644 examples/src/main/python/streaming/hdfs_wordcount.py create mode 100644 examples/src/main/python/streaming/network_wordcount.py create mode 100644 examples/src/main/python/streaming/stateful_network_wordcount.py create mode 100644 python/pyspark/streaming/__init__.py create mode 100644 python/pyspark/streaming/context.py create mode 100644 python/pyspark/streaming/dstream.py create mode 100644 python/pyspark/streaming/tests.py create mode 100644 python/pyspark/streaming/util.py create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index c74f86548ef85..4acbdf9d5e25f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -42,7 +40,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -55,9 +53,9 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -234,7 +232,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py new file mode 100644 index 0000000000000..40faff0ccc7db --- /dev/null +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. + + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: hdfs_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingHDFSWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cfa9c1ff5bfbc --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..18a9a5a452ffb --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..19fefbfc057a4 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -5,7 +5,7 @@ (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index d66e051b15371..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 85c04624da4a6..89d2e2e5b4a8e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -68,7 +68,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None): + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -104,14 +104,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -154,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3d1a34b281acc..08a0f0d8ffb3e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -114,6 +114,9 @@ def __ne__(self, other): def __repr__(self): return "<%s object>" % self.__class__.__name__ + def __hash__(self): + return hash(str(self)) + class FramedSerializer(Serializer): diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..d2644a1d4ffab --- /dev/null +++ b/python/pyspark/streaming/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.dstream import DStream + +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..dc9dc41121935 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys + +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import, JavaObject + +from pyspark import RDD, SparkConf +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer + +__all__ = ["StreamingContext"] + + +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. + + Also, it will update the port number (0) with real port + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream} various input sources. It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. `context.awaitTransformation()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. + """ + _transformerSerializer = None + + def __init__(self, sparkContext, batchDuration=None, jssc=None): + """ + Create a new StreamingContext. + + @param sparkContext: L{SparkContext} object. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches + """ + + self._sc = sparkContext + self._jvm = self._sc._jvm + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) + + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + + # register serializer for TransformFunction + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + + @classmethod + def getOrCreate(cls, checkpointPath, setupFunc): + """ + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. + + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams + """ + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + ssc = setupFunc() + ssc.checkpoint(checkpointPath) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(checkpointPath) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. + """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. + @param timeout: time to wait in seconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(int(timeout * 1000)) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing + of all received data to be completed + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + def remember(self, duration): + """ + Set each DStreams in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and releases them for garbage collection. + This method allows the developer to specify how to long to remember + the RDDs (if the developer wishes to query old data outside the + DStream computation). + + @param duration: Minimum duration (in seconds) that each DStream + should remember its RDDs + """ + self._jssc.remember(self._jduration(duration)) + + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory: HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input from TCP source hostname:port. Data is received using + a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited + lines. + + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects + """ + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be wrriten to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def _check_serializers(self, rdds): + # make sure they have same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i]._reserialize() + + def queueStream(self, rdds, oneAtATime=True, default=None): + """ + Create an input stream from an queue of RDDs or list. In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + + @param rdds: Queue of RDDs + @param oneAtATime: pick one rdd each time or pick all of them once. + @param default: The default rdd if no more in rdds + """ + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] + self._check_serializers(rdds) + + jrdds = ListConverter().convert([r._jrdd for r in rdds], + SparkContext._gateway._gateway_client) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. + """ + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..5ae5cf07f0137 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,621 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from itertools import chain, ifilter, imap +import operator +import time +from datetime import datetime + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.util import rddToFileName, TransformFunction +from pyspark.rdd import portable_hash +from pyspark.resultiterable import ResultIterable + +__all__ = ["DStream"] + + +class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self._sc = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. + """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func, True) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. + """ + return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) + + def reduce(self, func): + """ + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. + """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Return a new DStream by applying reduceByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Return a new DStream by applying combineByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def func(rdd): + return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) + return self.transform(func) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. + """ + return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def pprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. + """ + def takeAndPrint(time, rdd): + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % time + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Return a new DStream by applying a map function to the value of + each key-value pairs in this DStream without changing the key. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in this DStream without changing the key. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Persist the RDDs of this DStream with the given storage level + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Enable periodic checkpointing of RDDs of this DStream + + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed + """ + self.is_checkpointed = True + self._jdstream.checkpoint(self._ssc._jduration(interval)) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream by applying groupByKey on each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) + + def countByValue(self): + """ + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. + """ + return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save each RDD in this DStream as at text file, using string + representation of elements. + """ + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise + return self.foreachRDD(saveAsTextFile) + + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.func_code.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func) + + def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) + """ + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + other._jdstream.dstream(), jfunc) + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer + return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) + + def repartition(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. + """ + return self.transform(lambda rdd: rdd.repartition(numPartitions)) + + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + + @param other: Another DStream having the same interval (i.e., slideDuration) + as this DStream. + """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") + return self.transformWith(lambda a, b: a.union(b), other, True) + + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of this + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) + + def _jtime(self, timestamp): + """ Convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self._sc._jvm.Time(long(timestamp * 1000)) + + def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ + jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] + + def _validate_window_param(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._validate_window_param(windowDuration, slideDuration) + d = self._ssc._jduration(windowDuration) + if slideDuration is None: + return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) + s = self._ssc._jduration(slideDuration) + return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) + + def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + keyed = self.map(lambda x: (1, x)) + reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, + windowDuration, slideDuration, 1) + return reduced.map(lambda (k, v): v) + + def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ + return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, + windowDuration, slideDuration) + + def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + """ + keyed = self.map(lambda x: (x, 1)) + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, + windowDuration, slideDuration, numPartitions) + return counted.filter(lambda (k, v): v > 0).count() + + def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: Number of partitions of each RDD in the new DStream. + """ + ls = self.mapValues(lambda x: [x]) + grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], + windowDuration, slideDuration, numPartitions) + return grouped.mapValues(ResultIterable) + + def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None, + numPartitions=None, filterFunc=None): + """ + Return a new DStream by applying incremental `reduceByKey` over a sliding window. + + The reduced value of over a new window is calculated using the old window's reduce value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + + `invFunc` can be None, then it will reduce all the RDDs in window, could be slower + than having `invFunc`. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + @param filterFunc: function to filter expired key-value pairs; + only pairs that satisfy the function are retained + set this to null if you do not want to filter + """ + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) + + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) + if invReduceFunc: + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + else: + jinvReduceFunc = None + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + def updateStateByKey(self, updateFunc, numPartitions=None): + """ + Return a new "state" DStream where the state for each key is updated by applying + the given function on the previous state of the key and the new values of the key. + + @param updateFunc: State update function. If this function returns None, then + corresponding state key-value pair will be eliminated. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def reduceFunc(t, a, b): + if a is None: + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) + else: + g = a.cogroup(b, numPartitions) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) + return state.filter(lambda (k, v): v is not None) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + +class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ + def __init__(self, prev, func): + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer + self.is_cached = False + self.is_checkpointed = False + self._jdstream_val = None + + if (isinstance(prev, TransformedDStream) and + not prev.is_cached and not prev.is_checkpointed): + prev_func = prev.func + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..a8d876d0fa3b3 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. + if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000000000..86ee5aa04f252 --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." + suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/run-tests b/python/run-tests index f6a96841175e8..2f98443c30aef 100755 --- a/python/run-tests +++ b/python/run-tests @@ -81,6 +81,11 @@ function run_mllib_tests() { run_test "pyspark/mllib/tests.py" } +function run_streaming_tests() { + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -105,6 +111,7 @@ if [ $(which pypy) ]; then run_core_tests run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..2a7004e56ef53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } - /** + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000000000..213dff6a76354 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.language.existentials + +import py4j.GatewayServer + +import org.apache.spark.api.java._ +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Interval, Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ + + +/** + * Interface for Python callback function which is used to transform RDDs + */ +private[python] trait PythonTransformFunction { + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] +} + +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + +/** + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. + */ +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + } + + // for function.Function2 + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } + + private def writeObject(out: ObjectOutputStream): Unit = { + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) + } +} + +/** + * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that's happens to be + * implemented as a Python object. + */ +private[python] object PythonTransformFunctionSerializer { + + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ + private var serializer: PythonTransformFunctionSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } + + def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + serializer.dumps(id) + } + + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) + } +} + +/** + * Helper functions, which are called from Python via Py4J. + */ +private[python] object PythonDStream { + + /** + * can not access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) + } + + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * helper function for DStream.foreachRDD(), + * cannot be `foreachRDD`, it will confusing py4j + */ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction((pfunc)) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + /** + * convert list of RDD into queue of RDDs, for ssc.queueStream() + */ + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + +/** + * Base class for PythonDStream with some common methods + */ +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * Transformed DStream in Python. + */ +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends PythonDStream(parent, pfunc) { + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(rdd, validTime) + } else { + None + } + } +} + +/** + * Transformed from two DStreams in Python. + */ +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent, parent2) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * similar to StateDStream + */ +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) + extends PythonDStream(parent, reduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(lastState, rdd, validTime) + } else { + lastState + } + } +} + +/** + * similar to ReducedWindowedDStream + */ +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + val invReduceFunc = new TransformFunction(pinvReduceFunc) + + def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val currentTime = validTime + val current = new Interval(currentTime - windowDuration, currentTime) + val previous = current - slideDuration + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + val previousRDD = getOrCompute(previous.endTime) + + // for small window, reduce once will be better than twice + if (pinvReduceFunc != null && previousRDD.isDefined + && windowDuration >= slideDuration * 5) { + + // subtract the values from old RDDs + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) + } else { + previousRDD + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) + if (newRDDs.size > 0) { + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + } else { + subtracted + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) + if (currentRDDs.size > 0) { + func(None, Some(ssc.sc.union(currentRDDs)), validTime) + } else { + None + } + } + } +} From 18bd67c24b081f113b34455692451571c466df92 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 12 Oct 2014 13:08:42 -0700 Subject: [PATCH 265/315] [SPARK-3887] Send stracktrace in ConnectionManager error replies When reporting that a remote error occurred, the ConnectionManager should also log the stacktrace of the remote exception. This PR accomplishes this by sending the remote exception's stacktrace as the payload in the "negative ACK / error message." Author: Josh Rosen Closes #2741 from JoshRosen/propagate-cm-exceptions-to-sender and squashes the following commits: b5366cc [Josh Rosen] Explicitly encode error messages using UTF-8. cef18b3 [Josh Rosen] [SPARK-3887] Send stracktrace in ConnectionManager error messages. --- .../spark/network/nio/ConnectionManager.scala | 10 ++++++---- .../org/apache/spark/network/nio/Message.scala | 14 ++++++++++++++ .../network/nio/NioBlockTransferService.scala | 11 ++++------- .../spark/network/nio/ConnectionManagerSuite.scala | 6 ++++-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 6b00190c5eccc..9396b6ba84e7e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -748,9 +748,7 @@ private[nio] class ConnectionManager( } catch { case e: Exception => { logError(s"Exception was thrown while processing message", e) - val m = Message.createBufferMessage(bufferMessage.id) - m.hasError = true - ackMessage = Some(m) + ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) } } finally { sendMessage(connectionManagerId, ackMessage.getOrElse { @@ -913,8 +911,12 @@ private[nio] class ConnectionManager( } case scala.util.Success(ackMessage) => if (ackMessage.hasError) { + val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head + val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) + errorMsgByteBuf.get(errorMsgBytes) + val errorMsg = new String(errorMsgBytes, "utf-8") val e = new IOException( - "sendMessageReliably failed with ACK that signalled a remote error") + s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") if (!promise.tryFailure(e)) { logWarning("Ignore error because promise is completed", e) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 0b874c2891255..3ad04591da658 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -84,6 +85,19 @@ private[nio] object Message { createBufferMessage(new Array[ByteBuffer](0), ackId) } + /** + * Create a "negative acknowledgment" to notify a sender that an error occurred + * while processing its message. The exception's stacktrace will be formatted + * as a string, serialized into a byte array, and sent as the message payload. + */ + def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { + val exceptionString = Utils.exceptionString(exception) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8")) + val errorMessage = createBufferMessage(serializedExceptionString, ackId) + errorMessage.hasError = true + errorMessage + } + def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { case BUFFER_MESSAGE => new BufferMessage(header.id, diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..5add4fc433fb3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } catch { case e: Exception => { logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + Some(Message.createErrorMessage(e, msg.id)) } } case otherMessage: Any => - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" + logError(errorMsg) + Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) } } diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 9f49587cdc670..b70734dfe37cf 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -27,6 +27,7 @@ import scala.language.postfixOps import org.scalatest.FunSuite import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. @@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite { val manager = new ConnectionManager(0, conf, securityManager) val managerServer = new ConnectionManager(0, conf, securityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception + throw new Exception("Custom exception text") }) val size = 10 * 1024 * 1024 @@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - intercept[IOException] { + val exception = intercept[IOException] { Await.result(future, 1 second) } + assert(Utils.exceptionString(exception).contains("Custom exception text")) manager.stop() managerServer.stop() From e5be4de7bcf5aa7afc856fc665427ff2b22a0fcd Mon Sep 17 00:00:00 2001 From: NamelessAnalyst Date: Sun, 12 Oct 2014 14:18:55 -0700 Subject: [PATCH 266/315] SPARK-3716 [GraphX] Update Analytics.scala for partitionStrategy assignment Previously, when the val partitionStrategy was created it called a function in the Analytics object which was a copy of the PartitionStrategy.fromString() method. This function has been removed, and the assignment of partitionStrategy now uses the PartitionStrategy.fromString method instead. In this way, it better matches the declarations of edge/vertex StorageLevel variables. Author: NamelessAnalyst Closes #2569 from NamelessAnalyst/branch-1.1 and squashes the following commits: c24ff51 [NamelessAnalyst] Update Analytics.scala (cherry picked from commit 5a21e3e7e97f135c81c664098a723434b910f09d) Signed-off-by: Ankur Dave --- .../spark/examples/graphx/Analytics.scala | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c4317a6aec798..45527d9382fd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -46,17 +46,6 @@ object Analytics extends Logging { } val options = mutable.Map(optionsList: _*) - def pickPartitioner(v: String): PartitionStrategy = { - // TODO: Use reflection rather than listing all the partitioning strategies here. - v match { - case "RandomVertexCut" => RandomVertexCut - case "EdgePartition1D" => EdgePartition1D - case "EdgePartition2D" => EdgePartition2D - case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut - case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v) - } - } - val conf = new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") @@ -67,7 +56,7 @@ object Analytics extends Logging { sys.exit(1) } val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy") - .map(pickPartitioner(_)) + .map(PartitionStrategy.fromString(_)) val edgeStorageLevel = options.remove("edgeStorageLevel") .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY) val vertexStorageLevel = options.remove("vertexStorageLevel") @@ -107,7 +96,7 @@ object Analytics extends Logging { if (!outFname.isEmpty) { logWarning("Saving pageranks of pages to " + outFname) - pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname) + pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } sc.stop() @@ -129,7 +118,7 @@ object Analytics extends Logging { val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct()) + println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct()) sc.stop() case "triangles" => @@ -147,7 +136,7 @@ object Analytics extends Logging { minEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) - // TriangleCount requires the graph to be partitioned + // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) println("Triangles: " + triangles.vertices.map { From c86c9760374f331ab7ed173b0a022250635485d3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sun, 12 Oct 2014 15:41:27 -0700 Subject: [PATCH 267/315] [HOTFIX] Fix compilation error for Yarn 2.0.*-alpha This was reported in https://issues.apache.org/jira/browse/SPARK-3445. There are API differences between the 0.23.* and the 2.0.*-alpha branches that are not accounted for when this code was introduced. Author: Andrew Or Closes #2776 from andrewor14/fix-yarn-alpha and squashes the following commits: ec94752 [Andrew Or] Fix compilation error for 2.0.*-alpha --- .../src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5a20532315e59..5c7bca4541222 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -122,7 +122,7 @@ private[spark] class Client( * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. */ override def getClientToken(report: ApplicationReport): String = - Option(report.getClientToken).getOrElse("") + Option(report.getClientToken).map(_.toString).getOrElse("") } object Client { From fc616d51a510f82627b5be949a5941419834cf70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Dubovsk=C3=BD?= Date: Sun, 12 Oct 2014 22:03:26 -0700 Subject: [PATCH 268/315] [SPARK-3121] Wrong implementation of implicit bytesWritableConverter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit val path = ... //path to seq file with BytesWritable as type of both key and value val file = sc.sequenceFile[Array[Byte],Array[Byte]](path) file.take(1)(0)._1 This prints incorrect content of byte array. Actual content starts with correct one and some "random" bytes and zeros are appended. BytesWritable has two methods: getBytes() - return content of all internal array which is often longer then actual value stored. It usually contains the rest of previous longer values copyBytes() - return just begining of internal array determined by internal length property It looks like in implicit conversion between BytesWritable and Array[byte] getBytes is used instead of correct copyBytes. dbtsai Author: Jakub DubovskĂ˝ Author: Dubovsky Jakub Closes #2712 from james64/3121-bugfix and squashes the following commits: f85d24c [Jakub DubovskĂ˝] Test name changed, comments added 1b20d51 [Jakub DubovskĂ˝] Import placed correctly 406e26c [Jakub DubovskĂ˝] Scala style fixed f92ffa6 [Dubovsky Jakub] performance tuning 480f9cd [Dubovsky Jakub] Bug 3121 fixed --- .../scala/org/apache/spark/SparkContext.scala | 6 ++- .../org/apache/spark/SparkContextSuite.scala | 40 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/SparkContextSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 396cdd1247e07..b709b8880ba76 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import java.io._ import java.net.URI +import java.util.Arrays import java.util.concurrent.atomic.AtomicInteger import java.util.{Properties, UUID} import java.util.UUID.randomUUID @@ -1429,7 +1430,10 @@ object SparkContext extends Logging { simpleWritableConverter[Boolean, BooleanWritable](_.get) implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) } implicit def stringWritableConverter(): WritableConverter[String] = diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala new file mode 100644 index 0000000000000..31edad1c56c73 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import org.apache.hadoop.io.BytesWritable + +class SparkContextSuite extends FunSuite { + //Regression test for SPARK-3121 + test("BytesWritable implicit conversion is correct") { + val bytesWritable = new BytesWritable() + val inputArray = (1 to 10).map(_.toByte).toArray + bytesWritable.set(inputArray, 0, 10) + bytesWritable.set(inputArray, 0, 5) + + val converter = SparkContext.bytesWritableConverter() + val byteArray = converter.convert(bytesWritable) + assert(byteArray.length === 5) + + bytesWritable.set(inputArray, 0, 0) + val byteArray2 = converter.convert(bytesWritable) + assert(byteArray2.length === 0) + } +} From b4a7fa7a663c462bf537ca9d63af0dba6b4a8033 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 12 Oct 2014 22:48:54 -0700 Subject: [PATCH 269/315] [SPARK-3905][Web UI]The keys for sorting the columns of Executor page ,Stage page Storage page are incorrect Author: GuoQiang Li Closes #2763 from witgo/SPARK-3905 and squashes the following commits: 17d7990 [GuoQiang Li] The keys for sorting the columns of Executor page ,Stage page Storage page are incorrect --- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 12 ++++++------ .../scala/org/apache/spark/ui/jobs/StageTable.scala | 6 +++--- .../org/apache/spark/ui/storage/StoragePage.scala | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 2987dc04494a5..f0e43fbf70976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -71,19 +71,19 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr - + - - - - - } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2e67310594784..4ee7f08ab47a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -176,9 +176,9 @@ private[ui] class StageTableBase( {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} - - - + + + } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 716591c9ed449..83489ca0679ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -58,9 +58,9 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - - - + + + // scalastyle:on } From d8b8c210786dfb905d06ea0a21d633f7772d5d1a Mon Sep 17 00:00:00 2001 From: Ken Takagiwa Date: Sun, 12 Oct 2014 23:05:14 -0700 Subject: [PATCH 270/315] Add echo "Run streaming tests ..." Author: Ken Takagiwa Closes #2778 from giwa/patch-2 and squashes the following commits: a59f9a1 [Ken Takagiwa] Add echo "Run streaming tests ..." --- python/run-tests | 1 + 1 file changed, 1 insertion(+) diff --git a/python/run-tests b/python/run-tests index 2f98443c30aef..80acd002ab7eb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -82,6 +82,7 @@ function run_mllib_tests() { } function run_streaming_tests() { + echo "Run streaming tests ..." run_test "pyspark/streaming/util.py" run_test "pyspark/streaming/tests.py" } From 92e017fb894be1e8e2b2b5274fec4c31a7a4412e Mon Sep 17 00:00:00 2001 From: w00228970 Date: Sun, 12 Oct 2014 23:35:50 -0700 Subject: [PATCH 271/315] [SPARK-3899][Doc]fix wrong links in streaming doc There are three [Custom Receiver Guide] links in streaming doc, the first is wrong. Author: w00228970 Author: wangfei Closes #2749 from scwf/streaming-doc and squashes the following commits: 0cd76b7 [wangfei] update link tojump to the Akka-specific section 45b0646 [w00228970] wrong link in streaming doc --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5c21e912ea160..738309c668387 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -494,7 +494,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. From 942847fd94c920f7954ddf01f97263926e512b0e Mon Sep 17 00:00:00 2001 From: omgteam Date: Mon, 13 Oct 2014 09:59:41 -0700 Subject: [PATCH 272/315] Bug Fix: without unpersist method in RandomForest.scala During trainning Gradient Boosting Decision Tree on large-scale sparse data, spark spill hundreds of data onto disk. And find the bug below: In version 1.1.0 DecisionTree.scala, train Method, treeInput has been persisted in Memory, but without unpersist. It caused heavy DISK usage. In github version(1.2.0 maybe), RandomForest.scala, train Method, baggedInput has been persisted but without unpersisted too. After added unpersist, it works right. https://issues.apache.org/jira/browse/SPARK-3918 Author: omgteam Closes #2775 from omgteam/master and squashes the following commits: 815d543 [omgteam] adjust tab to spaces 1a36f83 [omgteam] Bug: fix without unpersist baggedInput in RandomForest.scala --- .../main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index fa7a26f17c3ca..ebbd8e0257209 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -176,6 +176,8 @@ private class RandomForest ( timer.stop("findBestSplits") } + baggedInput.unpersist() + timer.stop("total") logInfo("Internal timing for DecisionTree:") From 39ccabacf11abdd9afc8f9895084c6707ff35c85 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 13 Oct 2014 11:50:42 -0700 Subject: [PATCH 273/315] [SPARK-3861][SQL] Avoid rebuilding hash tables for broadcast joins on each partition Author: Reynold Xin Closes #2727 from rxin/SPARK-3861-broadcast-hash-2 and squashes the following commits: 9c7b1a2 [Reynold Xin] Revert "Reuse CompactBuffer in UniqueKeyHashedRelation." 97626a1 [Reynold Xin] Reuse CompactBuffer in UniqueKeyHashedRelation. 7fcffb5 [Reynold Xin] Make UniqueKeyHashedRelation private[joins]. 18eb214 [Reynold Xin] Merge branch 'SPARK-3861-broadcast-hash' into SPARK-3861-broadcast-hash-1 4b9d0c9 [Reynold Xin] UniqueKeyHashedRelation.get should return null if the value is null. e0ebdd1 [Reynold Xin] Added a test case. 90b58c0 [Reynold Xin] [SPARK-3861] Avoid rebuilding hash tables on each partition 0c0082b [Reynold Xin] Fix line length. cbc664c [Reynold Xin] Rename join -> joins package. a070d44 [Reynold Xin] Fix line length in HashJoin a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators. --- .../execution/joins/BroadcastHashJoin.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 34 ++---- .../sql/execution/joins/HashedRelation.scala | 109 ++++++++++++++++++ .../execution/joins/ShuffledHashJoin.scala | 5 +- .../execution/joins/HashedRelationSuite.scala | 63 ++++++++++ 5 files changed, 187 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index d88ab6367a1b3..8fd35880eedfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -49,14 +49,16 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) + val input: Array[Row] = buildPlan.executeCollect() + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + sparkContext.broadcast(hashed) } override def execute() = { val broadcastRelation = Await.result(broadcastFuture, 5.minute) streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) + hashJoin(streamedIter, broadcastRelation.value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 472b2e6ca6b4a..4012d757d5f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -43,34 +43,14 @@ trait HashJoin { override def output = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator = + @transient protected lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - new Iterator[Row] { private[this] var currentStreamedRow: Row = _ private[this] var currentHashMatches: CompactBuffer[Row] = _ @@ -107,7 +87,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) + currentHashMatches = hashedRelation.get(joinKeys.currentValue) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala new file mode 100644 index 0000000000000..38b8993b03f82 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +/** + * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete + * object. + */ +private[joins] sealed trait HashedRelation { + def get(key: Row): CompactBuffer[Row] +} + + +/** + * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. + */ +private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Serializable { + + override def get(key: Row) = hashTable.get(key) +} + + +/** + * A specialized [[HashedRelation]] that maps key into a single value. This implementation + * assumes the key is unique. + */ +private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Serializable { + + override def get(key: Row) = { + val v = hashTable.get(key) + if (v eq null) null else CompactBuffer(v) + } + + def getValue(key: Row): Row = hashTable.get(key) +} + + +// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. + + +private[joins] object HashedRelation { + + def apply( + input: Iterator[Row], + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use Spark's HashMap implementation. + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) + var currentRow: Row = null + + // Whether the join key is unique. If the key is unique, we can convert the underlying + // hash map into one specialized for this. + var keyIsUnique = true + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val rowKey = keyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += currentRow.copy() + } + } + + if (keyIsUnique) { + val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueKeyHashedRelation(uniqHashTable) + } else { + new GeneralHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 8247304c1dc2c..418c1c23e5546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -42,8 +42,9 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + hashJoin(streamIter, hashed) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala new file mode 100644 index 0000000000000..2aad01ded1acf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +class HashedRelationSuite extends FunSuite { + + // Key is simply the record itself + private val keyProjection = new Projection { + override def apply(row: Row): Row = row + } + + test("GeneralHashedRelation") { + val data = Array(Row(0), Row(1), Row(2), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[GeneralHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(Row(10)) === null) + + val data2 = CompactBuffer[Row](data(2)) + data2 += data(2) + assert(hashed.get(data(2)) == data2) + } + + test("UniqueKeyHashedRelation") { + val data = Array(Row(0), Row(1), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + assert(hashed.get(Row(10)) === null) + + val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] + assert(uniqHashed.getValue(data(0)) == data(0)) + assert(uniqHashed.getValue(data(1)) == data(1)) + assert(uniqHashed.getValue(data(2)) == data(2)) + assert(uniqHashed.getValue(Row(10)) == null) + } +} From 49bbdcb660edff7522430b329a300765164ccc44 Mon Sep 17 00:00:00 2001 From: yingjieMiao Date: Mon, 13 Oct 2014 13:11:55 -0700 Subject: [PATCH 274/315] [Spark] RDD take() method: overestimate too much In the comment (Line 1083), it says: "Otherwise, interpolate the number of partitions we need to try, but overestimate it by 50%." `(1.5 * num * partsScanned / buf.size).toInt` is the guess of "num of total partitions needed". In every iteration, we should consider the increment `(1.5 * num * partsScanned / buf.size).toInt - partsScanned` Existing implementation 'exponentially' grows `partsScanned ` ( roughly: `x_{n+1} >= (1.5 + 1) x_n`) This could be a performance problem. (unless this is the intended behavior) Author: yingjieMiao Closes #2648 from yingjieMiao/rdd_take and squashes the following commits: d758218 [yingjieMiao] scala style fix a8e74bb [yingjieMiao] python style fix 4b6e777 [yingjieMiao] infix operator style fix 4391d3b [yingjieMiao] typo fix. 692f4e6 [yingjieMiao] cap numPartsToTry c4483dc [yingjieMiao] style fix 1d2c410 [yingjieMiao] also change in rdd.py and AsyncRDD d31ff7e [yingjieMiao] handle the edge case after 1 iteration a2aa36b [yingjieMiao] RDD take method: overestimate too much --- .../scala/org/apache/spark/rdd/AsyncRDDActions.scala | 12 +++++++----- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +++++--- python/pyspark/rdd.py | 5 ++++- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..ede5568493cc0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. + // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // by 50%. We also cap the estimation in the end. if (results.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max(1, + (1.5 * num * partsScanned / results.size).toInt - partsScanned) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - results.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aba40d152e3e..71cabf61d4ee0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, // interpolate the number of partitions we need to try, but overestimate it by 50%. + // We also cap the estimation in the end. if (buf.size == 0) { numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e13bab946c44a..15be4bfec92f9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1070,10 +1070,13 @@ def take(self, num): # If we didn't find any rows after the previous iteration, # quadruple and retry. Otherwise, interpolate the number of # partitions we need to try, but overestimate it by 50%. + # We also cap the estimation in the end. if len(items) == 0: numPartsToTry = partsScanned * 4 else: - numPartsToTry = int(1.5 * num * partsScanned / len(items)) + # the first paramter of max is >=1 whenever partsScanned >= 2 + numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned + numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) left = num - len(items) From 46db277cc14bf3c1e4c4779baa8a40189b332d89 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 13 Oct 2014 13:31:14 -0700 Subject: [PATCH 275/315] [SPARK-3892][SQL] remove redundant type name Author: Daoyuan Wang Closes #2747 from adrian-wang/typename and squashes the following commits: 2824216 [Daoyuan Wang] remove redundant typeName fbaf340 [Daoyuan Wang] typename --- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 1d375b8754182..5bdacab664f8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -349,7 +349,6 @@ case object FloatType extends FractionalType { object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) - def typeName: String = "array" } /** @@ -395,8 +394,6 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - - def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -459,8 +456,6 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) - - def simpleName = "map" } /** From 2ac40da3f9fa6d45a59bb45b41606f1931ac5e81 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 13 Oct 2014 13:33:12 -0700 Subject: [PATCH 276/315] [SPARK-3407][SQL]Add Date type support Author: Daoyuan Wang Closes #2344 from adrian-wang/date and squashes the following commits: f15074a [Daoyuan Wang] remove outdated lines 2038085 [Daoyuan Wang] update return type 00fe81f [Daoyuan Wang] address lian cheng's comments 0df6ea1 [Daoyuan Wang] rebase and remove simple string bb1b1ef [Daoyuan Wang] remove failing test aa96735 [Daoyuan Wang] not cast for same type compare 30bf48b [Daoyuan Wang] resolve rebase conflict 617d1a8 [Daoyuan Wang] add date_udf case to white list c37e848 [Daoyuan Wang] comment update 5429212 [Daoyuan Wang] change to long f8f219f [Daoyuan Wang] revise according to Cheng Hao 0e0a4f5 [Daoyuan Wang] minor format 4ddcb92 [Daoyuan Wang] add java api for date 0e3110e [Daoyuan Wang] try to fix timezone issue 17fda35 [Daoyuan Wang] set test list 2dfbb5b [Daoyuan Wang] support date type --- .../spark/sql/catalyst/ScalaReflection.scala | 5 +- .../catalyst/analysis/HiveTypeCoercion.scala | 29 +++- .../spark/sql/catalyst/dsl/package.scala | 6 +- .../spark/sql/catalyst/expressions/Cast.scala | 98 +++++++++++-- .../sql/catalyst/expressions/literals.scala | 3 +- .../spark/sql/catalyst/types/dataTypes.scala | 12 +- .../ExpressionEvaluationSuite.scala | 35 ++++- .../apache/spark/sql/api/java/DataType.java | 5 + .../apache/spark/sql/api/java/DateType.java | 27 ++++ .../spark/sql/columnar/ColumnAccessor.scala | 4 + .../spark/sql/columnar/ColumnBuilder.scala | 3 + .../spark/sql/columnar/ColumnStats.scala | 20 ++- .../spark/sql/columnar/ColumnType.scala | 28 +++- .../scala/org/apache/spark/sql/package.scala | 10 ++ .../sql/types/util/DataTypeConversions.scala | 3 + .../sql/ScalaReflectionRelationSuite.scala | 5 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 1 + .../spark/sql/columnar/ColumnTypeSuite.scala | 7 +- .../sql/columnar/ColumnarTestUtils.scala | 3 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../execution/HiveCompatibilitySuite.scala | 10 ++ .../apache/spark/sql/hive/HiveContext.scala | 6 +- .../spark/sql/hive/HiveInspectors.scala | 9 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 + .../org/apache/spark/sql/hive/HiveQl.scala | 8 + .../date_1-0-23edf29bf7376c70d5ecf12720f4b1eb | 0 .../date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b | 0 ...date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 | 1 + ...date_1-11-480c5f024a28232b7857be327c992509 | 1 + ...date_1-12-4c0ed7fcb75770d8790575b586bf14f4 | 1 + .../date_1-13-44fc74c1993062c0a9522199ff27fea | 1 + ...date_1-14-4855a66124b16d1d0d003235995ac06b | 1 + ...date_1-15-8bc190dba0f641840b5e1e198a14c55b | 1 + ...date_1-16-23edf29bf7376c70d5ecf12720f4b1eb | 0 .../date_1-2-abdce0c0d14d3fc7441b7c134b02f99a | 0 .../date_1-3-df16364a220ff96a6ea1cd478cbc1d0b | 1 + .../date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 | 1 + .../date_1-5-5e70fc74158fbfca38134174360de12d | 0 .../date_1-6-df16364a220ff96a6ea1cd478cbc1d0b | 1 + .../date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 | 1 + .../date_1-8-1d5c58095cd52ea539d869f2ab1ab67d | 0 .../date_1-9-df16364a220ff96a6ea1cd478cbc1d0b | 1 + .../date_2-3-eedb73e0a622c2ab760b524f395dd4ba | 137 ++++++++++++++++++ .../date_2-4-3618dfde8da7c26f03bca72970db9ef7 | 137 ++++++++++++++++++ .../date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c | 12 ++ .../date_2-6-f4edce7cb20f325e8b69e787b2ae8882 | 0 .../date_3-3-4cf49e71b636df754871a675f9e4e24 | 0 .../date_3-4-e009f358964f6d1236cfc03283e2b06f | 1 + .../date_3-5-c26de4559926ddb0127d2dc5ea154774 | 0 .../date_4-0-b84f7e931d710dcbe3c5126d998285a8 | 0 .../date_4-1-6272f5e518f6a20bc96a5870ff315c4f | 0 .../date_4-2-4a0e7bde447ef616b98e0f55d2886de0 | 0 .../date_4-3-a23faa56b5d3ca9063a21f72b4278b00 | 0 .../date_4-4-bee09a7384666043621f68297cee2e68 | 1 + .../date_4-5-b84f7e931d710dcbe3c5126d998285a8 | 0 ...parison-0-69eec445bd045c9dc899fafa348d8495 | 1 + ...parison-1-fcc400871a502009c8680509e3869ec1 | 1 + ...arison-10-a9f2560c273163e11306d4f1dd1d9d54 | 1 + ...arison-11-4a7bac9ddcf40db6329faaec8e426543 | 1 + ...parison-2-b8598a4d0c948c2ddcf3eeef0abf2264 | 1 + ...parison-3-14d35f266be9cceb11a2ae09ec8b3835 | 1 + ...parison-4-c8865b14d53f2c2496fb69ee8191bf37 | 1 + ...parison-5-f2c907e64da8166a731ddc0ed19bad6c | 1 + ...parison-6-5606505a92bad10023ad9a3ef77eacc9 | 1 + ...mparison-7-47913d4aaf0d468ab3764cc3bfd68eb | 1 + ...parison-8-1e5ce4f833b6fba45618437c8fb7643c | 1 + ...parison-9-bcd987341fc1c38047a27d29dac6ae7c | 1 + ...e_join1-3-f71c7be760fb4de4eff8225f2c6614b2 | 22 +++ ...te_join1-4-70b9b49c55699fe94cfde069f5d197c | 0 ..._serde-10-d80e681519dcd8f5078c5602bb5befa9 | 0 ..._serde-11-29540200936bba47f17553547b409af7 | 0 ..._serde-12-c3c3275658b89d31fc504db31ae9f99c | 0 ..._serde-13-6c546456c81e635b6753e1552fac9129 | 1 + ..._serde-14-f8ba18cc7b0225b4022299c44d435101 | 1 + ..._serde-15-66fadc9bcea7d107a610758aa6f50ff3 | 0 ..._serde-16-1bd3345b46f77e17810978e56f9f7c6b | 0 ..._serde-17-a0df43062f8ab676ef728c9968443f12 | 0 ..._serde-18-b50ecc72ce9018ab12fb17568fef038a | 1 + ..._serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 | 1 + ..._serde-20-588516368d8c1533cb7bfb2157fd58c1 | 0 ..._serde-21-dfe166fe053468e738dca23ebe043091 | 0 ..._serde-22-45240a488fb708e432d2f45b74ef7e63 | 0 ..._serde-23-1742a51e4967a8d263572d890cd8d4a8 | 1 + ...e_serde-24-14fd49bd6fee907c1699f7b4e26685b | 1 + ..._serde-25-a199cf185184a25190d65c123d0694ee | 0 ..._serde-26-c5fa68d9aff36f22e5edc1b54332d0ab | 0 ..._serde-27-4d86c79f858866acec3c37f6598c2638 | 0 ..._serde-28-16a41fc9e0f51eb417c763bae8e9cadb | 1 + ..._serde-29-bd1cb09aacd906527b0bbf43bbded812 | 1 + ..._serde-30-7c80741f9f485729afc68609c55423a0 | 0 ...e_serde-31-da36cd1654aee055cb3650133c9d11f | 0 ..._serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 | 0 ..._serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 | 1 + ..._serde-34-6485841336c097895ad5b34f42c0745f | 1 + ..._serde-35-8651a7c351cbc07fb1af6193f6885de8 | 0 ..._serde-36-36e6041f53433482631018410bb62a99 | 0 ..._serde-37-3ddfd8ecb28991aeed588f1ea852c427 | 0 ..._serde-38-e6167e27465514356c557a77d956ea46 | 0 ..._serde-39-c1e17c93582656c12970c37bac153bf2 | 0 ..._serde-40-4a17944b9ec8999bb20c5ba5d4cb877c | 0 ...e_serde-8-cace4f60a08342f58fbe816a9c3a73cf | 137 ++++++++++++++++++ ...e_serde-9-436c3c61cc4278b54ac79c53c88ff422 | 12 ++ ...ate_udf-0-84604a42a5d7f2842f1eec10c689d447 | 0 ...ate_udf-1-5e8136f6a6503ae9bef9beca80fada13 | 0 ...te_udf-10-988ad9744096a29a3672a2d4c121299b | 1 + ...te_udf-11-a5100dd42201b5bc035a9d684cc21bdc | 1 + ...te_udf-12-eb7280a1f191344a99eaa0f805e8faff | 1 + ...te_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 | 1 + ...ate_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 | 1 + ...te_udf-15-d031ee50c119d7c6acafd53543dbd0c4 | 1 + ...te_udf-16-dc59f69e1685e8d923b187ec50d80f06 | 1 + ...te_udf-17-7d046d4efc568049cf3792470b6feab9 | 1 + ...te_udf-18-84604a42a5d7f2842f1eec10c689d447 | 0 ...te_udf-19-5e8136f6a6503ae9bef9beca80fada13 | 0 ...ate_udf-2-10e337c34d1e82a360b8599988f4b266 | 0 ...te_udf-20-10e337c34d1e82a360b8599988f4b266 | 0 ...ate_udf-3-29e406e613c0284b3e16a8943a4d31bd | 0 ...ate_udf-4-23653315213f578856ab5c3bd80c0264 | 0 ...ate_udf-5-891fd92a4787b9789f6d1f51c1eddc8a | 0 ...ate_udf-6-3473c118d20783eafb456043a2ee5d5b | 0 ...ate_udf-7-9fb5165824e161074565e7500959c1b2 | 0 ...ate_udf-8-badfe833681362092fc6345f888b1c21 | 1 + ...ate_udf-9-a8cbb039661d796beaa0d1564c58c563 | 1 + ...on_date-0-7ec1f3a845e2c49191460e15af30aa30 | 0 ...on_date-1-916193405ce5e020dcd32c58325db6fe | 0 ...n_date-10-a8dde9c0b5746dd770c9c262d23ffb10 | 1 + ...n_date-11-fdface2fb6eef67f15bb7d0de2294957 | 1 + ...n_date-12-9b945f8ece6e09ad28c866ff3a10cc24 | 1 + ...on_date-13-b7cb91c7c459798078a79071d329dbf | 1 + ...n_date-14-e4366325f3a0c4a8e92be59f4de73fce | 1 + ...n_date-15-a062a6e87867d8c8cfbdad97bedcbe5f | 1 + ...n_date-16-22a5627d9ac112665eae01d07a91c89c | 1 + ...on_date-17-b9ce94ef93cb16d629af7d7f8ee637e | 1 + ...n_date-18-72c6e9a4e0b434cef67144825346c687 | 1 + ...n_date-19-44e5165eb210559e420105073bc96125 | 1 + ...on_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 | 0 ...n_date-20-7ec1f3a845e2c49191460e15af30aa30 | 0 ...on_date-3-c938b08f57d588926a5d5fbfa4531012 | 0 ...on_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 | 0 ...on_date-5-a855aba47876561fd4fb095e09580686 | 0 ...on_date-6-1405c311915f27b0cc616c83d39eaacc | 2 + ...on_date-7-2ac950d8d5656549dd453e5464cb8530 | 5 + ...on_date-8-a425c11c12c9ce4c9c43d4fbccee5347 | 1 + ...on_date-9-aad6078a09b7bd8f5141437e86bb229f | 1 + ..._check-12-7e053ba4f9dea1e74c1d04c557c3adac | 6 + ..._check-13-45fb706ff448da1fe609c7ff76a80d4d | 0 ...on_date-6-f4d5c71145a9b7464685aa7d09cd4dfd | 40 +++++ ...on_date-7-a0bade1c77338d4f72962389a1f5bea2 | 0 ...on_date-8-21306adbd8be8ad75174ad9d3e42b73c | 0 150 files changed, 872 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java create mode 100644 sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb create mode 100644 sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b create mode 100644 sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 create mode 100644 sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 create mode 100644 sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 create mode 100644 sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea create mode 100644 sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b create mode 100644 sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b create mode 100644 sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb create mode 100644 sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a create mode 100644 sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b create mode 100644 sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 create mode 100644 sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d create mode 100644 sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b create mode 100644 sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 create mode 100644 sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d create mode 100644 sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b create mode 100644 sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba create mode 100644 sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 create mode 100644 sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c create mode 100644 sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 create mode 100644 sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 create mode 100644 sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f create mode 100644 sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 create mode 100644 sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 create mode 100644 sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f create mode 100644 sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 create mode 100644 sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 create mode 100644 sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 create mode 100644 sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c create mode 100644 sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 create mode 100644 sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb create mode 100644 sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c create mode 100644 sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c create mode 100644 sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 create mode 100644 sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c create mode 100644 sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 create mode 100644 sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 create mode 100644 sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c create mode 100644 sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 create mode 100644 sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 create mode 100644 sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 create mode 100644 sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b create mode 100644 sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 create mode 100644 sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a create mode 100644 sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 create mode 100644 sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 create mode 100644 sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 create mode 100644 sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 create mode 100644 sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 create mode 100644 sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b create mode 100644 sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee create mode 100644 sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab create mode 100644 sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 create mode 100644 sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb create mode 100644 sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 create mode 100644 sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 create mode 100644 sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f create mode 100644 sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 create mode 100644 sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 create mode 100644 sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f create mode 100644 sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 create mode 100644 sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 create mode 100644 sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 create mode 100644 sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 create mode 100644 sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 create mode 100644 sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c create mode 100644 sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf create mode 100644 sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 create mode 100644 sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 create mode 100644 sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 create mode 100644 sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b create mode 100644 sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc create mode 100644 sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff create mode 100644 sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 create mode 100644 sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 create mode 100644 sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 create mode 100644 sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 create mode 100644 sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 create mode 100644 sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 create mode 100644 sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 create mode 100644 sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 create mode 100644 sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 create mode 100644 sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd create mode 100644 sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 create mode 100644 sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a create mode 100644 sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b create mode 100644 sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 create mode 100644 sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 create mode 100644 sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 create mode 100644 sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 create mode 100644 sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe create mode 100644 sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 create mode 100644 sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 create mode 100644 sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 create mode 100644 sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf create mode 100644 sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce create mode 100644 sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f create mode 100644 sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c create mode 100644 sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e create mode 100644 sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 create mode 100644 sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 create mode 100644 sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 create mode 100644 sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 create mode 100644 sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 create mode 100644 sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 create mode 100644 sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 create mode 100644 sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc create mode 100644 sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 create mode 100644 sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 create mode 100644 sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f create mode 100644 sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac create mode 100644 sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d create mode 100644 sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd create mode 100644 sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 create mode 100644 sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b3ae8e6779700..3d4296f9d7068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -77,8 +77,9 @@ object ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 64881854df7a5..7c480de107e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -220,20 +220,39 @@ trait HiveTypeCoercion { case a: BinaryArithmetic if a.right.dataType == StringType => a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + // we should cast all timestamp/date/string compare into string compare + case p: BinaryPredicate if p.left.dataType == StringType + && p.right.dataType == DateType => + p.makeCopy(Array(p.left, Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == StringType => + p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) + p.makeCopy(Array(p.left, Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) + p.makeCopy(Array(Cast(p.left, StringType), p.right)) + case p: BinaryPredicate if p.left.dataType == TimestampType + && p.right.dataType == DateType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(a,b.map(Cast(_,TimestampType)))) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e) if e.dataType == StringType => Sum(Cast(e, DoubleType)) @@ -283,6 +302,8 @@ trait HiveTypeCoercion { // Skip if the type is boolean type already. Note that this extra cast should be removed // by optimizer.SimplifyCasts. case Cast(e, BooleanType) if e.dataType == BooleanType => e + // DateType should be null if be cast to boolean. + case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index deb622c39faf5..75b6e37c2a1f9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.language.implicitConversions @@ -119,6 +119,7 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def decimalToLiteral(d: BigDecimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -174,6 +175,9 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f626d09f037bc..8e5ee12e314bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { override def foldable = child.foldable override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true + case (StringType, DateType) => true case _ => child.nullable } @@ -42,6 +45,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // UDFToString private[this] def castToString: Any => Any = child.dataType match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case DateType => buildCast[Date](_, dateToString) case TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -56,7 +60,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => buildCast[String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + case DateType => + // Hive would return null when cast from date to boolean + buildCast[Date](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -95,6 +102,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { buildCast[Short](_, s => new Timestamp(s)) case ByteType => buildCast[Byte](_, b => new Timestamp(b)) + case DateType => + buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp case DecimalType => buildCast[BigDecimal](_, d => decimalToTimestamp(d)) @@ -130,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // Converts Timestamp to string according to Hive TimestampWritable convention private[this] def timestampToString(ts: Timestamp): String = { val timestampString = ts.toString - val formatted = Cast.threadLocalDateFormat.get.format(ts) + val formatted = Cast.threadLocalTimestampFormat.get.format(ts) if (timestampString.length > 19 && timestampString.substring(19) != ".0") { formatted + timestampString.substring(19) @@ -139,6 +148,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } } + // Converts Timestamp to string according to Hive TimestampWritable convention + private[this] def timestampToDateString(ts: Timestamp): String = { + Cast.threadLocalDateFormat.get.format(ts) + } + + // DateConverter + private[this] def castToDate: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => + try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null } + ) + case TimestampType => + // throw valid precision more than seconds, according to Hive. + // Timestamp.nanos is in 0 to 999,999,999, no more than a second. + buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + // Hive throws this exception as a Semantic Exception + // It is never possible to compare result when hive return with exception, so we can return null + // NULL is more reasonable here, since the query itself obeys the grammar. + case _ => _ => null + } + + // Date cannot be cast to long, according to hive + private[this] def dateToLong(d: Date) = null + + // Date cannot be cast to double, according to hive + private[this] def dateToDouble(d: Date) = null + + // Converts Date to string according to Hive DateWritable convention + private[this] def dateToString(d: Date): String = { + Cast.threadLocalDateFormat.get.format(d) + } + + // LongConverter private[this] def castToLong: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toLong catch { @@ -146,6 +188,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case DecimalType => @@ -154,6 +198,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } + // IntConverter private[this] def castToInt: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toInt catch { @@ -161,6 +206,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case DecimalType => @@ -169,6 +216,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } + // ShortConverter private[this] def castToShort: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toShort catch { @@ -176,6 +224,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case DecimalType => @@ -184,6 +234,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } + // ByteConverter private[this] def castToByte: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toByte catch { @@ -191,6 +242,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case DecimalType => @@ -199,6 +252,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } + // DecimalConverter private[this] def castToDecimal: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try BigDecimal(s.toDouble) catch { @@ -206,6 +260,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => // Note that we lose precision here. buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) @@ -213,6 +269,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) } + // DoubleConverter private[this] def castToDouble: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toDouble catch { @@ -220,6 +277,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case DecimalType => @@ -228,6 +287,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } + // FloatConverter private[this] def castToFloat: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toFloat catch { @@ -235,6 +295,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case DecimalType => @@ -245,17 +307,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] lazy val cast: Any => Any = dataType match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString - case BinaryType => castToBinary - case DecimalType => castToDecimal + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case DateType => castToDate case TimestampType => castToTimestamp - case BooleanType => castToBoolean - case ByteType => castToByte - case ShortType => castToShort - case IntegerType => castToInt - case FloatType => castToFloat - case LongType => castToLong - case DoubleType => castToDouble + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble } override def eval(input: Row): Any = { @@ -267,6 +330,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue() = { + new SimpleDateFormat("yyyy-MM-dd") + } + } + + // `SimpleDateFormat` is not thread-safe. + private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue() = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 78a0c55e4bbe5..ba240233cae61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ @@ -33,6 +33,7 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(d, DecimalType) case t: Timestamp => Literal(t, TimestampType) + case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 5bdacab664f8b..0cf139ebde417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.types -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag @@ -250,6 +250,16 @@ case object TimestampType extends NativeType { } } +case object DateType extends NativeType { + private[sql] type JvmType = Date + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + + private[sql] val ordering = new Ordering[JvmType] { + def compare(x: Date, y: Date) = x.compareTo(y) + } +} + abstract class NumericType extends NativeType with PrimitiveType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 692ed78a7292c..6dc5942023f9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet @@ -252,8 +252,11 @@ class ExpressionEvaluationSuite extends FunSuite { test("data type casting") { - val sts = "1970-01-01 00:00:01.1" - val ts = Timestamp.valueOf(sts) + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") checkEvaluation("abdef" cast DecimalType, null) @@ -266,8 +269,15 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) + checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) + // all convert to string type to check + checkEvaluation( + Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) + checkEvaluation( + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -316,6 +326,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) } + test("date") { + val d1 = Date.valueOf("1970-01-01") + val d2 = Date.valueOf("1970-01-02") + checkEvaluation(Literal(d1) < Literal(d2), true) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -323,6 +339,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(ts1) < Literal(ts2), true) } + test("date casting") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(Cast(d, ShortType), null) + checkEvaluation(Cast(d, IntegerType), null) + checkEvaluation(Cast(d, LongType), null) + checkEvaluation(Cast(d, FloatType), null) + checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, StringType), "1970-01-01") + checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37b4c8ffcba0b..37e88d72b9172 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -44,6 +44,11 @@ public abstract class DataType { */ public static final BooleanType BooleanType = new BooleanType(); + /** + * Gets the DateType object. + */ + public static final DateType DateType = new DateType(); + /** * Gets the TimestampType object. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java new file mode 100644 index 0000000000000..6677793baa365 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.sql.Date values. + * + * {@code DateType} is represented by the singleton object {@link DataType#DateType}. + */ +public class DateType extends DataType { + protected DateType() {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index c9faf0852142a..538dd5b734664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -92,6 +92,9 @@ private[sql] class FloatColumnAccessor(buffer: ByteBuffer) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, TIMESTAMP) @@ -118,6 +121,7 @@ private[sql] object ColumnAccessor { case BYTE.typeId => new ByteColumnAccessor(dup) case SHORT.typeId => new ShortColumnAccessor(dup) case STRING.typeId => new StringColumnAccessor(dup) + case DATE.typeId => new DateColumnAccessor(dup) case TIMESTAMP.typeId => new TimestampColumnAccessor(dup) case BINARY.typeId => new BinaryColumnAccessor(dup) case GENERIC.typeId => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 2e61a981375aa..300cef15bf8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -107,6 +107,8 @@ private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColum private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) + private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) @@ -151,6 +153,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case DATE.typeId => new DateColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 203a714e03c97..b34ab255d084a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -190,6 +190,24 @@ private[sql] class StringColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount) } +private[sql] class DateColumnStats extends ColumnStats { + var upper: Date = null + var lower: Date = null + var nullCount = 0 + + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Date] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 + } + } + + def collectedStatistics = Row(lower, upper, nullCount) +} + private[sql] class TimestampColumnStats extends ColumnStats { var upper: Timestamp = null var lower: Timestamp = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 198b5756676aa..ab66c85c4f242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.TypeTag @@ -335,7 +335,26 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { +private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { + override def extract(buffer: ByteBuffer) = { + val date = new Date(buffer.getLong()) + date + } + + override def append(v: Date, buffer: ByteBuffer): Unit = { + buffer.putLong(v.getTime) + } + + override def getField(row: Row, ordinal: Int) = { + row(ordinal).asInstanceOf[Date] + } + + override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + row(ordinal) = value + } +} + +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def extract(buffer: ByteBuffer) = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) @@ -376,7 +395,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -387,7 +406,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. -private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { +private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } @@ -407,6 +426,7 @@ private[sql] object ColumnType { case ShortType => SHORT case StringType => STRING case BinaryType => BINARY + case DateType => DATE case TimestampType => TIMESTAMP case _ => GENERIC } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f513eae9c2d13..e98d151286818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -165,6 +165,16 @@ package object sql { @DeveloperApi val TimestampType = catalyst.types.TimestampType + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Date` values. + * + * @group dataType + */ + @DeveloperApi + val DateType = catalyst.types.DateType + /** * :: DeveloperApi :: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 77353f4eb0227..e44cb08309523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -41,6 +41,7 @@ protected[sql] object DataTypeConversions { case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType + case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType case DecimalType => JDataType.DecimalType case DoubleType => JDataType.DoubleType @@ -80,6 +81,8 @@ protected[sql] object DataTypeConversions { BinaryType case booleanType: org.apache.spark.sql.api.java.BooleanType => BooleanType + case dateType: org.apache.spark.sql.api.java.DateType => + DateType case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index e24c521d24c7a..bfa9ea416266d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -34,6 +34,7 @@ case class ReflectData( byteField: Byte, booleanField: Boolean, decimalField: BigDecimal, + date: Date, timestampField: Timestamp, seqInt: Seq[Int]) @@ -76,7 +77,7 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 0cdbb3167ce36..6bdf741134e2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -30,6 +30,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) def testColumnStats[T <: NativeType, U <: ColumnStats]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 4fb1ecf1d532b..3f3f35d50188b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -33,8 +33,8 @@ class ColumnTypeSuite extends FunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - BOOLEAN -> 1, STRING -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, + STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -64,6 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(BOOLEAN, true, 1) checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(DATE, new Date(0L), 8) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 38b04dd959f70..a1f21219eaf2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import scala.util.Random -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -50,6 +50,7 @@ object ColumnarTestUtils { case STRING => Random.nextString(Random.nextInt(32)) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) + case DATE => new Date(Random.nextLong()) case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 6c9a9ab6c3418..21906e3fdcc6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,9 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index f54a21eb4fbb1..cb73f3da81e24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,9 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnBuilder(_) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 35e9c9939d4b7..463888551a359 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -343,6 +343,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ct_case_insensitive", "database_location", "database_properties", + "date_2", + "date_3", + "date_4", + "date_comparison", + "date_join1", + "date_serde", + "date_udf", "decimal_1", "decimal_4", "decimal_join", @@ -604,8 +611,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "part_inherit_tbl_props", "part_inherit_tbl_props_empty", "part_inherit_tbl_props_with_star", + "partition_date", "partition_schema1", "partition_serde_format", + "partition_type_check", "partition_varchar1", "partition_wise_fileformat4", "partition_wise_fileformat5", @@ -904,6 +913,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "union7", "union8", "union9", + "union_date", "union_lateralview", "union_ppr", "union_remove_11", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad3b39f81413..8b5a90159e1bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable +import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -357,7 +358,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, TimestampType, BinaryType) + ShortType, DecimalType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -372,6 +373,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" + case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (other, tpe) if primitiveTypes contains tpe => other.toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index d633c42c6bd67..1977618b4c9f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -39,6 +39,7 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType case c: Class[_] if c == classOf[hadoopIo.Text] => StringType case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType @@ -49,6 +50,7 @@ private[hive] trait HiveInspectors { // java class case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType case c: Class[_] if c == classOf[HiveDecimal] => DecimalType case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType @@ -93,6 +95,7 @@ private[hive] trait HiveInspectors { System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) bytes } + case d: hiveIo.DateWritable => d.get case t: hiveIo.TimestampWritable => t.getTimestamp case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) case list: java.util.List[_] => list.map(unwrap) @@ -108,6 +111,7 @@ private[hive] trait HiveInspectors { case str: String => str case p: java.math.BigDecimal => p case p: Array[Byte] => p + case p: java.sql.Date => p case p: java.sql.Timestamp => p } @@ -147,6 +151,7 @@ private[hive] trait HiveInspectors { case l: Byte => l: java.lang.Byte case b: BigDecimal => new HiveDecimal(b.underlying()) case b: Array[Byte] => b + case d: java.sql.Date => d case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => @@ -173,6 +178,7 @@ private[hive] trait HiveInspectors { case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => @@ -211,6 +217,8 @@ private[hive] trait HiveInspectors { case _: JavaBinaryObjectInspector => BinaryType case _: WritableHiveDecimalObjectInspector => DecimalType case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableDateObjectInspector => DateType + case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType case _: JavaTimestampObjectInspector => TimestampType case _: WritableVoidObjectInspector => NullType @@ -238,6 +246,7 @@ private[hive] trait HiveInspectors { case ShortType => shortTypeInfo case StringType => stringTypeInfo case DecimalType => decimalTypeInfo + case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index addd5bed8426d..c5fee5e4702f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -186,6 +186,7 @@ object HiveMetastoreTypes extends RegexParsers { "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | "decimal" ^^^ DecimalType | + "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType @@ -235,6 +236,7 @@ object HiveMetastoreTypes extends RegexParsers { case LongType => "bigint" case BinaryType => "binary" case BooleanType => "boolean" + case DateType => "date" case DecimalType => "decimal" case TimestampType => "timestamp" case NullType => "void" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7cc14dc7a9c9e..2b599157d15d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.sql.Date + import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils @@ -317,6 +319,7 @@ private[hive] object HiveQl { case Token("TOK_STRING", Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType case Token("TOK_TIMESTAMP", Nil) => TimestampType case Token("TOK_BINARY", Nil) => BinaryType case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) @@ -924,6 +927,8 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), DecimalType) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) + case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DateType) /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) @@ -1047,6 +1052,9 @@ private[hive] object HiveQl { case ast: ASTNode if ast.getType == HiveParser.StringLiteral => Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 new file mode 100644 index 0000000000000..5a368ab170261 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 @@ -0,0 +1 @@ +2012-01-01 2011-01-01 2011-01-01 00:00:00 2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 new file mode 100644 index 0000000000000..edb4b1f84001b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea new file mode 100644 index 0000000000000..2af0b9ed3a68c --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea @@ -0,0 +1 @@ +true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b new file mode 100644 index 0000000000000..d8dfbf60007bd --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b @@ -0,0 +1 @@ +2001-01-28 2001-02-28 2001-03-28 2001-04-28 2001-05-28 2001-06-28 2001-07-28 2001-08-28 2001-09-28 2001-10-28 2001-11-28 2001-12-28 diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b new file mode 100644 index 0000000000000..4f6a1bc4273e0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b @@ -0,0 +1 @@ +true true true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba new file mode 100644 index 0000000000000..db973ab292d5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba @@ -0,0 +1,137 @@ +2010-10-20 7291 +2010-10-20 3198 +2010-10-20 3014 +2010-10-20 2630 +2010-10-20 1610 +2010-10-20 1599 +2010-10-20 1531 +2010-10-20 1142 +2010-10-20 1064 +2010-10-20 897 +2010-10-20 361 +2010-10-21 7291 +2010-10-21 3198 +2010-10-21 3014 +2010-10-21 2646 +2010-10-21 2630 +2010-10-21 1610 +2010-10-21 1599 +2010-10-21 1531 +2010-10-21 1142 +2010-10-21 1064 +2010-10-21 897 +2010-10-21 361 +2010-10-22 3198 +2010-10-22 3014 +2010-10-22 2646 +2010-10-22 2630 +2010-10-22 1610 +2010-10-22 1599 +2010-10-22 1531 +2010-10-22 1142 +2010-10-22 1064 +2010-10-22 897 +2010-10-22 361 +2010-10-23 7274 +2010-10-23 5917 +2010-10-23 5904 +2010-10-23 5832 +2010-10-23 3171 +2010-10-23 3085 +2010-10-23 2932 +2010-10-23 1805 +2010-10-23 650 +2010-10-23 426 +2010-10-23 384 +2010-10-23 272 +2010-10-24 7282 +2010-10-24 3198 +2010-10-24 3014 +2010-10-24 2646 +2010-10-24 2630 +2010-10-24 2571 +2010-10-24 2254 +2010-10-24 1610 +2010-10-24 1599 +2010-10-24 1531 +2010-10-24 897 +2010-10-24 361 +2010-10-25 7291 +2010-10-25 3198 +2010-10-25 3014 +2010-10-25 2646 +2010-10-25 2630 +2010-10-25 1610 +2010-10-25 1599 +2010-10-25 1531 +2010-10-25 1142 +2010-10-25 1064 +2010-10-25 897 +2010-10-25 361 +2010-10-26 7291 +2010-10-26 3198 +2010-10-26 3014 +2010-10-26 2662 +2010-10-26 2646 +2010-10-26 2630 +2010-10-26 1610 +2010-10-26 1599 +2010-10-26 1531 +2010-10-26 1142 +2010-10-26 1064 +2010-10-26 897 +2010-10-26 361 +2010-10-27 7291 +2010-10-27 3198 +2010-10-27 3014 +2010-10-27 2630 +2010-10-27 1610 +2010-10-27 1599 +2010-10-27 1531 +2010-10-27 1142 +2010-10-27 1064 +2010-10-27 897 +2010-10-27 361 +2010-10-28 7291 +2010-10-28 3198 +2010-10-28 3014 +2010-10-28 2646 +2010-10-28 2630 +2010-10-28 1610 +2010-10-28 1599 +2010-10-28 1531 +2010-10-28 1142 +2010-10-28 1064 +2010-10-28 897 +2010-10-28 361 +2010-10-29 7291 +2010-10-29 3198 +2010-10-29 3014 +2010-10-29 2646 +2010-10-29 2630 +2010-10-29 1610 +2010-10-29 1599 +2010-10-29 1531 +2010-10-29 1142 +2010-10-29 1064 +2010-10-29 897 +2010-10-29 361 +2010-10-30 5917 +2010-10-30 5904 +2010-10-30 3171 +2010-10-30 3085 +2010-10-30 2932 +2010-10-30 2018 +2010-10-30 1805 +2010-10-30 650 +2010-10-30 426 +2010-10-30 384 +2010-10-30 272 +2010-10-31 7282 +2010-10-31 3198 +2010-10-31 2571 +2010-10-31 1610 +2010-10-31 1599 +2010-10-31 1531 +2010-10-31 897 +2010-10-31 361 diff --git a/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 new file mode 100644 index 0000000000000..1b0ea7b9eec84 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 @@ -0,0 +1,137 @@ +2010-10-31 361 +2010-10-31 897 +2010-10-31 1531 +2010-10-31 1599 +2010-10-31 1610 +2010-10-31 2571 +2010-10-31 3198 +2010-10-31 7282 +2010-10-30 272 +2010-10-30 384 +2010-10-30 426 +2010-10-30 650 +2010-10-30 1805 +2010-10-30 2018 +2010-10-30 2932 +2010-10-30 3085 +2010-10-30 3171 +2010-10-30 5904 +2010-10-30 5917 +2010-10-29 361 +2010-10-29 897 +2010-10-29 1064 +2010-10-29 1142 +2010-10-29 1531 +2010-10-29 1599 +2010-10-29 1610 +2010-10-29 2630 +2010-10-29 2646 +2010-10-29 3014 +2010-10-29 3198 +2010-10-29 7291 +2010-10-28 361 +2010-10-28 897 +2010-10-28 1064 +2010-10-28 1142 +2010-10-28 1531 +2010-10-28 1599 +2010-10-28 1610 +2010-10-28 2630 +2010-10-28 2646 +2010-10-28 3014 +2010-10-28 3198 +2010-10-28 7291 +2010-10-27 361 +2010-10-27 897 +2010-10-27 1064 +2010-10-27 1142 +2010-10-27 1531 +2010-10-27 1599 +2010-10-27 1610 +2010-10-27 2630 +2010-10-27 3014 +2010-10-27 3198 +2010-10-27 7291 +2010-10-26 361 +2010-10-26 897 +2010-10-26 1064 +2010-10-26 1142 +2010-10-26 1531 +2010-10-26 1599 +2010-10-26 1610 +2010-10-26 2630 +2010-10-26 2646 +2010-10-26 2662 +2010-10-26 3014 +2010-10-26 3198 +2010-10-26 7291 +2010-10-25 361 +2010-10-25 897 +2010-10-25 1064 +2010-10-25 1142 +2010-10-25 1531 +2010-10-25 1599 +2010-10-25 1610 +2010-10-25 2630 +2010-10-25 2646 +2010-10-25 3014 +2010-10-25 3198 +2010-10-25 7291 +2010-10-24 361 +2010-10-24 897 +2010-10-24 1531 +2010-10-24 1599 +2010-10-24 1610 +2010-10-24 2254 +2010-10-24 2571 +2010-10-24 2630 +2010-10-24 2646 +2010-10-24 3014 +2010-10-24 3198 +2010-10-24 7282 +2010-10-23 272 +2010-10-23 384 +2010-10-23 426 +2010-10-23 650 +2010-10-23 1805 +2010-10-23 2932 +2010-10-23 3085 +2010-10-23 3171 +2010-10-23 5832 +2010-10-23 5904 +2010-10-23 5917 +2010-10-23 7274 +2010-10-22 361 +2010-10-22 897 +2010-10-22 1064 +2010-10-22 1142 +2010-10-22 1531 +2010-10-22 1599 +2010-10-22 1610 +2010-10-22 2630 +2010-10-22 2646 +2010-10-22 3014 +2010-10-22 3198 +2010-10-21 361 +2010-10-21 897 +2010-10-21 1064 +2010-10-21 1142 +2010-10-21 1531 +2010-10-21 1599 +2010-10-21 1610 +2010-10-21 2630 +2010-10-21 2646 +2010-10-21 3014 +2010-10-21 3198 +2010-10-21 7291 +2010-10-20 361 +2010-10-20 897 +2010-10-20 1064 +2010-10-20 1142 +2010-10-20 1531 +2010-10-20 1599 +2010-10-20 1610 +2010-10-20 2630 +2010-10-20 3014 +2010-10-20 3198 +2010-10-20 7291 diff --git a/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 b/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 b/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f new file mode 100644 index 0000000000000..66d2220d06de2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f @@ -0,0 +1 @@ +1 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 b/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f b/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 b/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 b/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 new file mode 100644 index 0000000000000..b61affde4ffce --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 @@ -0,0 +1 @@ +2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 new file mode 100644 index 0000000000000..b7305b903edca --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 @@ -0,0 +1,22 @@ +1064 2010-10-20 1064 2010-10-20 +1142 2010-10-21 1142 2010-10-21 +1599 2010-10-22 1599 2010-10-22 +361 2010-10-23 361 2010-10-23 +897 2010-10-24 897 2010-10-24 +1531 2010-10-25 1531 2010-10-25 +1610 2010-10-26 1610 2010-10-26 +3198 2010-10-27 3198 2010-10-27 +1064 2010-10-28 1064 2010-10-28 +1142 2010-10-29 1142 2010-10-29 +1064 2000-11-20 1064 2000-11-20 +1142 2000-11-21 1142 2000-11-21 +1599 2000-11-22 1599 2000-11-22 +361 2000-11-23 361 2000-11-23 +897 2000-11-24 897 2000-11-24 +1531 2000-11-25 1531 2000-11-25 +1610 2000-11-26 1610 2000-11-26 +3198 2000-11-27 3198 2000-11-27 +1064 2000-11-28 1064 2000-11-28 +1142 2000-11-28 1064 2000-11-28 +1064 2000-11-28 1142 2000-11-28 +1142 2000-11-28 1142 2000-11-28 diff --git a/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c b/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 b/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 b/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c b/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 b/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b b/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 b/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 b/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 b/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 b/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee b/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab b/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 b/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 b/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f b/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 b/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 b/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 b/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 b/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 b/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 b/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c b/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf new file mode 100644 index 0000000000000..16c03e7276fec --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf @@ -0,0 +1,137 @@ +Baltimore New York 2010-10-20 -30.0 1064 +Baltimore New York 2010-10-20 23.0 1142 +Baltimore New York 2010-10-20 6.0 1599 +Chicago New York 2010-10-20 42.0 361 +Chicago New York 2010-10-20 24.0 897 +Chicago New York 2010-10-20 15.0 1531 +Chicago New York 2010-10-20 -6.0 1610 +Chicago New York 2010-10-20 -2.0 3198 +Baltimore New York 2010-10-21 17.0 1064 +Baltimore New York 2010-10-21 105.0 1142 +Baltimore New York 2010-10-21 28.0 1599 +Chicago New York 2010-10-21 142.0 361 +Chicago New York 2010-10-21 77.0 897 +Chicago New York 2010-10-21 53.0 1531 +Chicago New York 2010-10-21 -5.0 1610 +Chicago New York 2010-10-21 51.0 3198 +Baltimore New York 2010-10-22 -12.0 1064 +Baltimore New York 2010-10-22 54.0 1142 +Baltimore New York 2010-10-22 18.0 1599 +Chicago New York 2010-10-22 2.0 361 +Chicago New York 2010-10-22 24.0 897 +Chicago New York 2010-10-22 16.0 1531 +Chicago New York 2010-10-22 -6.0 1610 +Chicago New York 2010-10-22 -11.0 3198 +Baltimore New York 2010-10-23 18.0 272 +Baltimore New York 2010-10-23 -10.0 1805 +Baltimore New York 2010-10-23 6.0 3171 +Chicago New York 2010-10-23 3.0 384 +Chicago New York 2010-10-23 32.0 426 +Chicago New York 2010-10-23 1.0 650 +Chicago New York 2010-10-23 11.0 3085 +Baltimore New York 2010-10-24 12.0 1599 +Baltimore New York 2010-10-24 20.0 2571 +Chicago New York 2010-10-24 10.0 361 +Chicago New York 2010-10-24 113.0 897 +Chicago New York 2010-10-24 -5.0 1531 +Chicago New York 2010-10-24 -17.0 1610 +Chicago New York 2010-10-24 -3.0 3198 +Baltimore New York 2010-10-25 -25.0 1064 +Baltimore New York 2010-10-25 92.0 1142 +Baltimore New York 2010-10-25 106.0 1599 +Chicago New York 2010-10-25 31.0 361 +Chicago New York 2010-10-25 -1.0 897 +Chicago New York 2010-10-25 43.0 1531 +Chicago New York 2010-10-25 6.0 1610 +Chicago New York 2010-10-25 -16.0 3198 +Baltimore New York 2010-10-26 -22.0 1064 +Baltimore New York 2010-10-26 123.0 1142 +Baltimore New York 2010-10-26 90.0 1599 +Chicago New York 2010-10-26 12.0 361 +Chicago New York 2010-10-26 0.0 897 +Chicago New York 2010-10-26 29.0 1531 +Chicago New York 2010-10-26 -17.0 1610 +Chicago New York 2010-10-26 6.0 3198 +Baltimore New York 2010-10-27 -18.0 1064 +Baltimore New York 2010-10-27 49.0 1142 +Baltimore New York 2010-10-27 92.0 1599 +Chicago New York 2010-10-27 148.0 361 +Chicago New York 2010-10-27 -11.0 897 +Chicago New York 2010-10-27 70.0 1531 +Chicago New York 2010-10-27 8.0 1610 +Chicago New York 2010-10-27 21.0 3198 +Baltimore New York 2010-10-28 -4.0 1064 +Baltimore New York 2010-10-28 -14.0 1142 +Baltimore New York 2010-10-28 -14.0 1599 +Chicago New York 2010-10-28 2.0 361 +Chicago New York 2010-10-28 2.0 897 +Chicago New York 2010-10-28 -11.0 1531 +Chicago New York 2010-10-28 3.0 1610 +Chicago New York 2010-10-28 -18.0 3198 +Baltimore New York 2010-10-29 -24.0 1064 +Baltimore New York 2010-10-29 21.0 1142 +Baltimore New York 2010-10-29 -2.0 1599 +Chicago New York 2010-10-29 -12.0 361 +Chicago New York 2010-10-29 -11.0 897 +Chicago New York 2010-10-29 15.0 1531 +Chicago New York 2010-10-29 -18.0 1610 +Chicago New York 2010-10-29 -4.0 3198 +Baltimore New York 2010-10-30 14.0 272 +Baltimore New York 2010-10-30 -1.0 1805 +Baltimore New York 2010-10-30 5.0 3171 +Chicago New York 2010-10-30 -6.0 384 +Chicago New York 2010-10-30 -10.0 426 +Chicago New York 2010-10-30 -5.0 650 +Chicago New York 2010-10-30 -5.0 3085 +Baltimore New York 2010-10-31 -1.0 1599 +Baltimore New York 2010-10-31 -14.0 2571 +Chicago New York 2010-10-31 -25.0 361 +Chicago New York 2010-10-31 -18.0 897 +Chicago New York 2010-10-31 -4.0 1531 +Chicago New York 2010-10-31 -22.0 1610 +Chicago New York 2010-10-31 -15.0 3198 +Cleveland New York 2010-10-30 -23.0 2018 +Cleveland New York 2010-10-30 -12.0 2932 +Cleveland New York 2010-10-29 -4.0 2630 +Cleveland New York 2010-10-29 -19.0 2646 +Cleveland New York 2010-10-29 -12.0 3014 +Cleveland New York 2010-10-28 3.0 2630 +Cleveland New York 2010-10-28 -6.0 2646 +Cleveland New York 2010-10-28 1.0 3014 +Cleveland New York 2010-10-27 16.0 2630 +Cleveland New York 2010-10-27 27.0 3014 +Cleveland New York 2010-10-26 4.0 2630 +Cleveland New York 2010-10-26 -27.0 2646 +Cleveland New York 2010-10-26 -11.0 2662 +Cleveland New York 2010-10-26 13.0 3014 +Cleveland New York 2010-10-25 -4.0 2630 +Cleveland New York 2010-10-25 81.0 2646 +Cleveland New York 2010-10-25 42.0 3014 +Cleveland New York 2010-10-24 5.0 2254 +Cleveland New York 2010-10-24 -11.0 2630 +Cleveland New York 2010-10-24 -20.0 2646 +Cleveland New York 2010-10-24 -9.0 3014 +Cleveland New York 2010-10-23 -21.0 2932 +Cleveland New York 2010-10-22 1.0 2630 +Cleveland New York 2010-10-22 -25.0 2646 +Cleveland New York 2010-10-22 -3.0 3014 +Cleveland New York 2010-10-21 3.0 2630 +Cleveland New York 2010-10-21 29.0 2646 +Cleveland New York 2010-10-21 72.0 3014 +Cleveland New York 2010-10-20 -8.0 2630 +Cleveland New York 2010-10-20 -15.0 3014 +Washington New York 2010-10-23 -25.0 5832 +Washington New York 2010-10-23 -21.0 5904 +Washington New York 2010-10-23 -18.0 5917 +Washington New York 2010-10-30 -27.0 5904 +Washington New York 2010-10-30 -16.0 5917 +Washington New York 2010-10-20 -2.0 7291 +Washington New York 2010-10-21 22.0 7291 +Washington New York 2010-10-23 -16.0 7274 +Washington New York 2010-10-24 -26.0 7282 +Washington New York 2010-10-25 9.0 7291 +Washington New York 2010-10-26 4.0 7291 +Washington New York 2010-10-27 26.0 7291 +Washington New York 2010-10-28 45.0 7291 +Washington New York 2010-10-29 1.0 7291 +Washington New York 2010-10-31 -18.0 7282 diff --git a/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b new file mode 100644 index 0000000000000..83c33400edb47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b @@ -0,0 +1 @@ +0 3333 -3333 -3332 3332 diff --git a/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc new file mode 100644 index 0000000000000..4a2462bb3929b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc @@ -0,0 +1 @@ +NULL 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 new file mode 100644 index 0000000000000..977f0d24c58cc --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 @@ -0,0 +1 @@ +0 3333 -3333 -3333 3333 diff --git a/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 new file mode 100644 index 0000000000000..44d1f45e4eb73 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 @@ -0,0 +1 @@ +1970-01-01 08:00:00 1969-12-31 16:00:00 2013-06-19 07:00:00 2013-06-18 17:00:00 diff --git a/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 new file mode 100644 index 0000000000000..645b71d8d61e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 @@ -0,0 +1 @@ +true true true true diff --git a/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 new file mode 100644 index 0000000000000..51863e9a14e4b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 @@ -0,0 +1 @@ +2010-10-20 diff --git a/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 new file mode 100644 index 0000000000000..4043ee1cbdd40 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 @@ -0,0 +1 @@ +2010-10-31 diff --git a/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd b/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 b/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a b/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b b/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 b/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 new file mode 100644 index 0000000000000..18d17ea11b53e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 @@ -0,0 +1 @@ +1304665200 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe b/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 new file mode 100644 index 0000000000000..7ed6ff82de6bc --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 @@ -0,0 +1 @@ +5 diff --git a/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 new file mode 100644 index 0000000000000..b4de394767536 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 @@ -0,0 +1 @@ +11 diff --git a/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 new file mode 100644 index 0000000000000..64bb6b746dcea --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 @@ -0,0 +1 @@ +30 diff --git a/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 b/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 b/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 b/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 b/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc new file mode 100644 index 0000000000000..051ca3d3c28e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc @@ -0,0 +1,2 @@ +2000-01-01 +2013-08-08 diff --git a/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 new file mode 100644 index 0000000000000..24192eefd2caf --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 @@ -0,0 +1,5 @@ +165 val_165 2000-01-01 2 +238 val_238 2000-01-01 2 +27 val_27 2000-01-01 2 +311 val_311 2000-01-01 2 +86 val_86 2000-01-01 2 diff --git a/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac new file mode 100644 index 0000000000000..91ba621412d72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac @@ -0,0 +1,6 @@ +1 11 2008-01-01 +2 12 2008-01-01 +3 13 2008-01-01 +7 17 2008-01-01 +8 18 2008-01-01 +8 28 2008-01-01 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d b/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd new file mode 100644 index 0000000000000..7941f53d8d4c7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd @@ -0,0 +1,40 @@ +1064 2000-11-20 +1064 2000-11-20 +1142 2000-11-21 +1142 2000-11-21 +1599 2000-11-22 +1599 2000-11-22 +361 2000-11-23 +361 2000-11-23 +897 2000-11-24 +897 2000-11-24 +1531 2000-11-25 +1531 2000-11-25 +1610 2000-11-26 +1610 2000-11-26 +3198 2000-11-27 +3198 2000-11-27 +1064 2000-11-28 +1064 2000-11-28 +1142 2000-11-28 +1142 2000-11-28 +1064 2010-10-20 +1064 2010-10-20 +1142 2010-10-21 +1142 2010-10-21 +1599 2010-10-22 +1599 2010-10-22 +361 2010-10-23 +361 2010-10-23 +897 2010-10-24 +897 2010-10-24 +1531 2010-10-25 +1531 2010-10-25 +1610 2010-10-26 +1610 2010-10-26 +3198 2010-10-27 +3198 2010-10-27 +1064 2010-10-28 +1064 2010-10-28 +1142 2010-10-29 +1142 2010-10-29 diff --git a/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 b/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c b/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c new file mode 100644 index 0000000000000..e69de29bb2d1d From 56102dc2d849c221f325a7888cd66abb640ec887 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 13 Oct 2014 13:36:39 -0700 Subject: [PATCH 277/315] [SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation This PR adds a new rule `CheckAggregation` to the analyzer to provide better error message for non-aggregate attributes with aggregation. Author: Cheng Lian Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits: 5246004 [Cheng Lian] Passes test suites bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation --- .../sql/catalyst/analysis/Analyzer.scala | 36 ++++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fe83eb12502dc..82553063145b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution), + CheckResolution, + CheckAggregation), Batch("AnalysisOperators", fixedPoint, EliminateAnalysisOperators) ) @@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Checks for non-aggregated attributes with aggregation + */ + object CheckAggregation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan.transform { + case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => + def isValidAggregateExpression(expr: Expression): Boolean = expr match { + case _: AggregateExpression => true + case e: Attribute => groupingExprs.contains(e) + case e if groupingExprs.contains(e) => true + case e if e.references.isEmpty => true + case e => e.children.forall(isValidAggregateExpression) + } + + aggregateExprs.foreach { e => + if (!isValidAggregateExpression(e)) { + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") + } + } + + aggregatePlan + } + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs - + Project(aggregate.output, Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) } - } - + protected def containsAggregate(condition: Expression): Boolean = condition .collect { case ae: AggregateExpression => ae } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a94022c0cf6e3..15f6ba4f72bbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll @@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) } + + test("throw errors for non-aggregate attributes with aggregation") { + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed + } + } + + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + } } From d3cdf9128ae828dc7f1893439f66a0de68c6e527 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 13 Oct 2014 13:40:20 -0700 Subject: [PATCH 278/315] [SPARK-3529] [SQL] Delete the temp files after test exit There are lots of temporal files created by TestHive under the /tmp by default, which may cause potential performance issue for testing. This PR will automatically delete them after test exit. Author: Cheng Hao Closes #2393 from chenghao-intel/delete_temp_on_exit and squashes the following commits: 3a6511f [Cheng Hao] Remove the temp dir after text exit --- .../main/scala/org/apache/spark/sql/hive/TestHive.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a4354c1379c63..9a9e2eda6bcd4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ @@ -71,11 +72,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") setConf("hive.metastore.warehouse.dir", warehousePath) + Utils.registerShutdownDeleteDir(new File(warehousePath)) + Utils.registerShutdownDeleteDir(new File(metastorePath)) } val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") testTempDir.delete() testTempDir.mkdir() + Utils.registerShutdownDeleteDir(testTempDir) // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) @@ -121,8 +125,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - hiveFilesTemp.deleteOnExit() - + Utils.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) From 73da9c26b0e2e8bf0ab055906211727a7097c963 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 13 Oct 2014 13:43:41 -0700 Subject: [PATCH 279/315] [SPARK-3771][SQL] AppendingParquetOutputFormat should use reflection to prevent from breaking binary-compatibility. Original problem is [SPARK-3764](https://issues.apache.org/jira/browse/SPARK-3764). `AppendingParquetOutputFormat` uses a binary-incompatible method `context.getTaskAttemptID`. This causes binary-incompatible of Spark itself, i.e. if Spark itself is built against hadoop-1, the artifact is for only hadoop-1, and vice versa. Author: Takuya UESHIN Closes #2638 from ueshin/issues/SPARK-3771 and squashes the following commits: efd3784 [Takuya UESHIN] Add a comment to explain the reason to use reflection. ec213c1 [Takuya UESHIN] Use reflection to prevent breaking binary-compatibility. --- .../spark/sql/parquet/ParquetTableOperations.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ffb732347d30a..1f4237d7ede65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -331,13 +331,21 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val taskId: TaskID = context.getTaskAttemptID.getTaskID + val taskId: TaskID = getTaskAttemptID(context).getTaskID val partition: Int = taskId.getId val filename = s"part-r-${partition + offset}.parquet" val committer: FileOutputCommitter = getOutputCommitter(context).asInstanceOf[FileOutputCommitter] new Path(committer.getWorkPath, filename) } + + // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. + // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions + // are the same, so the method calls are source-compatible but NOT binary-compatible because + // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. + private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { + context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] + } } /** From e10d71e7e58bf2ec0f1942cb2f0602396ab866b4 Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Mon, 13 Oct 2014 13:45:34 -0700 Subject: [PATCH 280/315] [SPARK-3559][SQL] Remove unnecessary columns from List of needed Column Ids in Hive Conf Author: Venkata Ramana G Author: Venkata Ramana Gollamudi Closes #2713 from gvramana/remove_unnecessary_columns and squashes the following commits: b7ba768 [Venkata Ramana Gollamudi] Added comment and checkstyle fix 6a93459 [Venkata Ramana Gollamudi] cloned hiveconf for each TableScanOperators so that only required columns are added --- .../scala/org/apache/spark/sql/hive/TableReader.scala | 6 ++++-- .../spark/sql/hive/execution/HiveTableScan.scala | 10 ++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 84fafcde63d05..0de29d5cffd0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} @@ -52,7 +53,8 @@ private[hive] class HadoopTableReader( @transient attributes: Seq[Attribute], @transient relation: MetastoreRelation, - @transient sc: HiveContext) + @transient sc: HiveContext, + @transient hiveExtraConf: HiveConf) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -63,7 +65,7 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf)) + sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) def broadcastedHiveConf = _broadcastedHiveConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 577ca928b43b6..a32147584f6f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -64,8 +64,14 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } + // Create a local copy of hiveconf,so that scan specific modifications should not impact + // other queries @transient - private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) + private[this] val hiveExtraConf = new HiveConf(context.hiveconf) + + @transient + private[this] val hadoopReader = + new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -97,7 +103,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } - addColumnMetadataToConf(context.hiveconf) + addColumnMetadataToConf(hiveExtraConf) /** * Prunes partitions not involve the query plan. From 371321cadee7df39258bd374eb59c1e32451d96b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 13 Oct 2014 13:46:34 -0700 Subject: [PATCH 281/315] [SQL] Add type checking debugging functions Adds some functions that were very useful when trying to track down the bug from #2656. This change also changes the tree output for query plans to include the `'` prefix to unresolved nodes and `!` prefix to nodes that refer to non-existent attributes. Author: Michael Armbrust Closes #2657 from marmbrus/debugging and squashes the following commits: 654b926 [Michael Armbrust] Clean-up, add tests 763af15 [Michael Armbrust] Add typeChecking debugging functions 8c69303 [Michael Armbrust] Add inputSet, references to QueryPlan. Improve tree string with a prefix to denote invalid or unresolved nodes. fbeab54 [Michael Armbrust] Better toString, factories for AttributeSet. --- .../catalyst/expressions/AttributeSet.scala | 23 +++-- .../sql/catalyst/expressions/Projection.scala | 2 + .../expressions/namedExpressions.scala | 4 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 23 +++++ .../catalyst/plans/logical/LogicalPlan.scala | 8 +- .../plans/logical/basicOperators.scala | 5 -- .../spark/sql/execution/debug/package.scala | 85 +++++++++++++++++++ .../sql/execution/debug/DebuggingSuite.scala | 33 +++++++ 8 files changed, 163 insertions(+), 20 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index c3a08bbdb6bc7..2b4969b7cfec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,19 +17,26 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.Star + protected class AttributeEquals(val a: Attribute) { override def hashCode() = a.exprId.hashCode() - override def equals(other: Any) = other match { - case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId - case otherAttribute => false + override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId + case (a1, a2) => a1 == a2 } } object AttributeSet { - /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */ - def apply(baseSet: Seq[Attribute]) = { - new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) - } + def apply(a: Attribute) = + new AttributeSet(Set(new AttributeEquals(a))) + + /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ + def apply(baseSet: Seq[Expression]) = + new AttributeSet( + baseSet + .flatMap(_.references) + .map(new AttributeEquals(_)).toSet) } /** @@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq + + override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 204904ecf04db..e7e81a21fdf03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -39,6 +39,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } new GenericRow(outputArray) } + + override def toString = s"Row => [${exprArray.mkString(",")}]" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e5a958d599393..d023db44d8543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -57,6 +57,8 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => + override def references = AttributeSet(this) + def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute @@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def references = AttributeSet(this :: Nil) - override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index af9e4d86e995a..dcbbb62c0aca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -31,6 +31,25 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ def outputSet: AttributeSet = AttributeSet(output) + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + + /** + * The set of all attributes that are input to this operator by its children. + */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** + * Attributes that are referenced by expressions but not provided by this nodes children. + * Subclasses should override this method if they produce attributes internally as it is used by + * assertions designed to prevent the construction of invalid plans. + */ + def missingInput: AttributeSet = references -- inputSet + /** * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, @@ -132,4 +151,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) + + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 4f8ad8a7e0223..882e9c6110089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -53,12 +53,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) } - /** - * Returns the set of attributes that this node takes as - * input from its children. - */ - lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) - /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan @@ -68,6 +62,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + override protected def statePrefix = if (!resolved) "'" else super.statePrefix + /** * Returns true if all its children of this query plan have been resolved. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f8e9930ac270d..14b03c7445c13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -138,11 +138,6 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - /** The set of all AttributeReferences required for this aggregation. */ - def references = - AttributeSet( - groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) - override def output = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a9535a750bcd7..61be5ed2db65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.types._ /** * :: DeveloperApi :: @@ -56,6 +57,23 @@ package object debug { case _ => } } + + def typeCheck(): Unit = { + val plan = query.queryExecution.executedPlan + val visited = new collection.mutable.HashSet[TreeNodeRef]() + val debugPlan = plan transform { + case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => + visited += new TreeNodeRef(s) + TypeCheck(s) + } + try { + println(s"Results returned: ${debugPlan.execute().count()}") + } catch { + case e: Exception => + def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) + println(s"Deepest Error: ${unwrap(e)}") + } + } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -115,4 +133,71 @@ package object debug { } } } + + /** + * :: DeveloperApi :: + * Helper functions for checking that runtime types match a given schema. + */ + @DeveloperApi + object TypeCheck { + def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { + case (null, _) => + + case (row: Row, StructType(fields)) => + row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + case (s: Seq[_], ArrayType(elemType, _)) => + s.foreach(typeCheck(_, elemType)) + case (m: Map[_, _], MapType(keyType, valueType, _)) => + m.keys.foreach(typeCheck(_, keyType)) + m.values.foreach(typeCheck(_, valueType)) + + case (_: Long, LongType) => + case (_: Int, IntegerType) => + case (_: String, StringType) => + case (_: Float, FloatType) => + case (_: Byte, ByteType) => + case (_: Short, ShortType) => + case (_: Boolean, BooleanType) => + case (_: Double, DoubleType) => + + case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") + } + } + + /** + * :: DeveloperApi :: + * Augments SchemaRDDs with debug methods. + */ + @DeveloperApi + private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { + import TypeCheck._ + + override def nodeName = "" + + /* Only required when defining this class in a REPL. + override def makeCopy(args: Array[Object]): this.type = + TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] + */ + + def output = child.output + + def children = child :: Nil + + def execute() = { + child.execute().map { row => + try typeCheck(row, child.schema) catch { + case e: Exception => + sys.error( + s""" + |ERROR WHEN TYPE CHECKING QUERY + |============================== + |$e + |======== BAD TREE ============ + |$child + """.stripMargin) + } + row + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala new file mode 100644 index 0000000000000..87c28c334d228 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.debug + +import org.scalatest.FunSuite + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ + +class DebuggingSuite extends FunSuite { + test("SchemaRDD.debug()") { + testData.debug() + } + + test("SchemaRDD.typeCheck()") { + testData.typeCheck() + } +} \ No newline at end of file From e6e37701f12be82fa77dc28d825ddd36a1ab7594 Mon Sep 17 00:00:00 2001 From: chirag Date: Mon, 13 Oct 2014 13:47:26 -0700 Subject: [PATCH 282/315] SPARK-3807: SparkSql does not work for tables created using custom serde SparkSql crashes on selecting tables using custom serde. Example: ---------------- CREATE EXTERNAL TABLE table_name PARTITIONED BY ( a int) ROW FORMAT 'SERDE "org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer" with serdeproperties("serialization.format"="org.apache.thrift.protocol.TBinaryProtocol","serialization.class"="ser_class") STORED AS SEQUENCEFILE; The following exception is seen on running a query like 'select * from table_name limit 1': ERROR CliDriver: org.apache.hadoop.hive.serde2.SerDeException: java.lang.NullPointerException at org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer.initialize(ThriftDeserializer.java:68) at org.apache.hadoop.hive.ql.plan.TableDesc.getDeserializer(TableDesc.java:80) at org.apache.spark.sql.hive.execution.HiveTableScan.addColumnMetadataToConf(HiveTableScan.scala:86) at org.apache.spark.sql.hive.execution.HiveTableScan.(HiveTableScan.scala:100) at org.apache.spark.sql.hive.HiveStrategies$HiveTableScans$$anonfun$14.apply(HiveStrategies.scala:188) at org.apache.spark.sql.hive.HiveStrategies$HiveTableScans$$anonfun$14.apply(HiveStrategies.scala:188) at org.apache.spark.sql.SQLContext$SparkPlanner.pruneFilterProject(SQLContext.scala:364) at org.apache.spark.sql.hive.HiveStrategies$HiveTableScans$.apply(HiveStrategies.scala:184) at org.apache.spark.sql.catalyst.planning.QueryPlanner$$anonfun$1.apply(QueryPlanner.scala:58) at org.apache.spark.sql.catalyst.planning.QueryPlanner$$anonfun$1.apply(QueryPlanner.scala:58) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at org.apache.spark.sql.catalyst.planning.QueryPlanner.apply(QueryPlanner.scala:59) at org.apache.spark.sql.catalyst.planning.QueryPlanner.planLater(QueryPlanner.scala:54) at org.apache.spark.sql.execution.SparkStrategies$BasicOperators$.apply(SparkStrategies.scala:280) at org.apache.spark.sql.catalyst.planning.QueryPlanner$$anonfun$1.apply(QueryPlanner.scala:58) at org.apache.spark.sql.catalyst.planning.QueryPlanner$$anonfun$1.apply(QueryPlanner.scala:58) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at org.apache.spark.sql.catalyst.planning.QueryPlanner.apply(QueryPlanner.scala:59) at org.apache.spark.sql.SQLContext$QueryExecution.sparkPlan$lzycompute(SQLContext.scala:402) at org.apache.spark.sql.SQLContext$QueryExecution.sparkPlan(SQLContext.scala:400) at org.apache.spark.sql.SQLContext$QueryExecution.executedPlan$lzycompute(SQLContext.scala:406) at org.apache.spark.sql.SQLContext$QueryExecution.executedPlan(SQLContext.scala:406) at org.apache.spark.sql.hive.HiveContext$QueryExecution.stringResult(HiveContext.scala:406) at org.apache.spark.sql.hive.thriftserver.SparkSQLDriver.run(SparkSQLDriver.scala:59) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver.processCmd(SparkSQLCLIDriver.scala:291) at org.apache.hadoop.hive.cli.CliDriver.processLine(CliDriver.java:413) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver$.main(SparkSQLCLIDriver.scala:226) at org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver.main(SparkSQLCLIDriver.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source) at java.lang.reflect.Method.invoke(Unknown Source) at org.apache.spark.deploy.SparkSubmit$.launch(SparkSubmit.scala:328) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:75) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: java.lang.NullPointerException Author: chirag Closes #2674 from chiragaggarwal/branch-1.1 and squashes the following commits: 370c31b [chirag] SPARK-3807: Add a test case to validate the fix. 1f26805 [chirag] SPARK-3807: SparkSql does not work for tables created using custom serde (Incorporated Review Comments) ba4bc0c [chirag] SPARK-3807: SparkSql does not work for tables created using custom serde 5c73b72 [chirag] SPARK-3807: SparkSql does not work for tables created using custom serde (cherry picked from commit 925e22d3132b983a2fcee31e3878b680c7ff92da) Signed-off-by: Michael Armbrust --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveTableScan.scala | 6 +++++- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c5fee5e4702f6..75a19656af110 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -305,7 +305,7 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.getSd.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index a32147584f6f4..5b83b77d80a22 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -86,10 +86,14 @@ case class HiveTableScan( ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) + val tableDesc = relation.tableDesc + val deserializer = tableDesc.getDeserializerClass.newInstance + deserializer.initialize(hiveConf, tableDesc.getProperties) + // Specifies types and object inspectors of columns to be scanned. val structOI = ObjectInspectorUtils .getStandardObjectInspector( - relation.tableDesc.getDeserializer.getObjectInspector, + deserializer.getObjectInspector, ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2829105f43716..3e100775e4981 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -802,6 +802,9 @@ class HiveQuerySuite extends HiveComparisonTest { clear() } + createQueryTest("select from thrift based table", + "SELECT * from src_thrift") + // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. } From 9d9ca91fef70eca6fc576be9c99aed5d8ce6e68b Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Mon, 13 Oct 2014 13:49:11 -0700 Subject: [PATCH 283/315] [SQL]Small bug in unresolved.scala name should throw exception with name instead of exprId. Author: Liquan Pei Closes #2758 from Ishiihara/SparkSQL-bug and squashes the following commits: aa36a3b [Liquan Pei] small bug --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 67570a6f73c36..77d84e1687e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -88,7 +88,7 @@ case class Star( mapFunction: Attribute => Expression = identity[Attribute]) extends Attribute with trees.LeafNode[Expression] { - override def name = throw new UnresolvedException(this, "exprId") + override def name = throw new UnresolvedException(this, "name") override def exprId = throw new UnresolvedException(this, "exprId") override def dataType = throw new UnresolvedException(this, "dataType") override def nullable = throw new UnresolvedException(this, "nullable") From 9eb49d4134e23a15142fb592d54d920e89bd8786 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 13 Oct 2014 13:50:27 -0700 Subject: [PATCH 284/315] [SPARK-3809][SQL] Fixes test suites in hive-thriftserver As scwf pointed out, `HiveThriftServer2Suite` isn't effective anymore after the Thrift server was made a daemon. On the other hand, these test suites were known flaky, PR #2214 tried to fix them but failed because of unknown Jenkins build error. This PR fixes both sets of issues. In this PR, instead of watching `start-thriftserver.sh` output, the test code start a `tail` process to watch the log file. A `Thread.sleep` has to be introduced because the `kill` command used in `stop-thriftserver.sh` is not synchronous. As for the root cause of the mysterious Jenkins build failure. Please refer to [this comment](https://github.com/apache/spark/pull/2675#issuecomment-58464189) below for details. ---- (Copied from PR description of #2214) This PR fixes two issues of `HiveThriftServer2Suite` and brings 1 enhancement: 1. Although metastore, warehouse directories and listening port are randomly chosen, all test cases share the same configuration. Due to parallel test execution, one of the two test case is doomed to fail 2. We caught any exceptions thrown from a test case and print diagnosis information, but forgot to re-throw the exception... 3. When the forked server process ends prematurely (e.g., fails to start), the `serverRunning` promise is completed with a failure, preventing the test code to keep waiting until timeout. So, embarrassingly, this test suite was failing continuously for several days but no one had ever noticed it... Fortunately no bugs in the production code were covered under the hood. Author: Cheng Lian Author: wangfei Closes #2675 from liancheng/fix-thriftserver-tests and squashes the following commits: 1c384b7 [Cheng Lian] Minor code cleanup, restore the logging level hack in TestHive.scala 7805c33 [wangfei] reset SPARK_TESTING to avoid loading Log4J configurations in testing class paths af2b5a9 [Cheng Lian] Removes log level hacks from TestHiveContext d116405 [wangfei] make sure that log4j level is INFO ee92a82 [Cheng Lian] Relaxes timeout 7fd6757 [Cheng Lian] Fixes test suites in hive-thriftserver --- .../sql/hive/thriftserver/CliSuite.scala | 13 ++- .../thriftserver/HiveThriftServer2Suite.scala | 86 +++++++++++-------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index d68dd090b5e6c..fc97a25be34be 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -62,8 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" + // If we haven't found all expected answers... if (next.get() < expectedAnswers.size) { + // If another expected answer is found... if (line.startsWith(expectedAnswers(next.get()))) { + // If all expected answers have been found... if (next.incrementAndGet() == expectedAnswers.size) { foundAllExpectedAnswers.trySuccess(()) } @@ -77,7 +80,8 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { Future { val exitValue = process.exitValue() - logInfo(s"Spark SQL CLI process exit value: $exitValue") + foundAllExpectedAnswers.tryFailure( + new SparkException(s"Spark SQL CLI process exit value: $exitValue")) } try { @@ -98,6 +102,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |End CliSuite failure output |=========================== """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() @@ -109,7 +114,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute)( + runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" @@ -120,7 +125,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { -> "Time taken: ", "SELECT COUNT(*) FROM hive_test;" -> "5", - "DROP TABLE hive_test" + "DROP TABLE hive_test;" -> "Time taken: " ) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 38977ff162097..e3b4e45a3d68c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} -import scala.sys.process.{Process, ProcessLogger} - import java.io.File import java.net.ServerSocket import java.sql.{DriverManager, Statement} import java.util.concurrent.TimeoutException +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.Try + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.scalatest.FunSuite @@ -41,25 +41,25 @@ import org.apache.spark.sql.catalyst.util.getTempFilePath class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - private val listeningHost = "localhost" - private val listeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - private val warehousePath = getTempFilePath("warehouse") - private val metastorePath = getTempFilePath("metastore") - private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) { - val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + val warehousePath = getTempFilePath("warehouse") + val metastorePath = getTempFilePath("metastore") + val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + val listeningHost = "localhost" + val listeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. + val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } val command = - s"""$serverScript + s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri @@ -68,29 +68,40 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort """.stripMargin.split("\\s+").toSeq - val serverStarted = Promise[Unit]() + val serverRunning = Promise[Unit]() val buffer = new ArrayBuffer[String]() + val LOGGING_MARK = + s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " + var logTailingProcess: Process = null + var logFilePath: String = null - def captureOutput(source: String)(line: String) { - buffer += s"$source> $line" + def captureLogOutput(line: String): Unit = { + buffer += line if (line.contains("ThriftBinaryCLIService listening on")) { - serverStarted.success(()) + serverRunning.success(()) } } - val process = Process(command).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL Thrift server process exit value: $exitValue") + def captureThriftServerOutput(source: String)(line: String): Unit = { + if (line.startsWith(LOGGING_MARK)) { + logFilePath = line.drop(LOGGING_MARK.length).trim + // Ensure that the log file is created so that the `tail' command won't fail + Try(new File(logFilePath).createNewFile()) + logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") + .run(ProcessLogger(captureLogOutput, _ => ())) + } } + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger( + captureThriftServerOutput("stdout"), + captureThriftServerOutput("stderr"))) + val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" val user = System.getProperty("user.name") try { - Await.result(serverStarted.future, timeout) + Await.result(serverRunning.future, timeout) val connection = DriverManager.getConnection(jdbcUri, user, "") val statement = connection.createStatement() @@ -122,10 +133,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |End HiveThriftServer2Suite failure output |========================================= """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() - process.destroy() + Process(stopScript).run().exitValue() + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Thread.sleep(3.seconds.toMillis) + Option(logTailingProcess).map(_.destroy()) + Option(logFilePath).map(new File(_).delete()) } } From 4d26aca770f7dd50eee1ed7855e9eda68b5a7ffa Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 13 Oct 2014 22:46:49 -0700 Subject: [PATCH 285/315] [SPARK-3912][Streaming] Fixed flakyFlumeStreamSuite @harishreedharan @pwendell See JIRA for diagnosis of the problem https://issues.apache.org/jira/browse/SPARK-3912 The solution was to reimplement it. 1. Find a free port (by binding and releasing a server-scoket), and then use that port 2. Remove thread.sleep()s, instead repeatedly try to create a sender and send data and check whether data was sent. Use eventually() to minimize waiting time. 3. Check whether all the data was received, without caring about batches. Author: Tathagata Das Closes #2773 from tdas/flume-test-fix and squashes the following commits: 93cd7f6 [Tathagata Das] Reimplimented FlumeStreamSuite to be more robust. --- .../streaming/flume/FlumeStreamSuite.scala | 166 +++++++++++------- 1 file changed, 102 insertions(+), 64 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 33235d150b4a5..13943ed5442b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,103 +17,141 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer import java.nio.charset.Charset +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase} -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.handler.codec.compression._ +class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") + + var ssc: StreamingContext = null + var transceiver: NettyTransceiver = null -class FlumeStreamSuite extends TestSuiteBase { + after { + if (ssc != null) { + ssc.stop() + } + if (transceiver != null) { + transceiver.close() + } + } test("flume input stream") { - runFlumeStreamTest(false) + testFlumeStream(testCompression = false) } test("flume input compressed stream") { - runFlumeStreamTest(true) + testFlumeStream(testCompression = true) + } + + /** Run test on flume stream */ + private def testFlumeStream(testCompression: Boolean): Unit = { + val input = (1 to 100).map { _.toString } + val testPort = findFreePort() + val outputBuffer = startContext(testPort, testCompression) + writeAndVerify(input, testPort, outputBuffer, testCompression) + } + + /** Find a free port */ + private def findFreePort(): Int = { + Utils.startServiceOnPort(23456, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + })._2 } - - def runFlumeStreamTest(enableDecompression: Boolean) { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val (flumeStream, testPort) = - Utils.startServiceOnPort(9997, (trialPort: Int) => { - val dstream = FlumeUtils.createStream( - ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) - (dstream, trialPort) - }) + /** Setup and start the streaming context */ + private def startContext( + testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + ssc = new StreamingContext(conf, Milliseconds(200)) + val flumeStream = FlumeUtils.createStream( + ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() + outputBuffer + } - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - var client: AvroSourceProtocol = null - - if (enableDecompression) { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], - new NettyTransceiver(new InetSocketAddress("localhost", testPort), - new CompressionChannelFactory(6))) - } else { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver) - } + /** Send data to the flume receiver and verify whether the data was received */ + private def writeAndVerify( + input: Seq[String], + testPort: Int, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], + enableCompression: Boolean + ) { + val testAddress = new InetSocketAddress("localhost", testPort) - for (i <- 0 until input.size) { + val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - client.append(event) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + event } - Thread.sleep(1000) - - val startTime = System.currentTimeMillis() - while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) - Thread.sleep(100) + eventually(timeout(10 seconds), interval(100 milliseconds)) { + // if last attempted transceiver had succeeded, close it + if (transceiver != null) { + transceiver.close() + transceiver = null + } + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + client should not be null + + // Send data + val status = client.appendBatch(inputEvents.toList) + status should be (avro.Status.OK) } - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val decoder = Charset.forName("UTF-8").newDecoder() - - assert(outputBuffer.size === input.length) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - val str = decoder.decode(outputBuffer(i).head.event.getBody) - assert(str.toString === input(i).toString) - assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + + val decoder = Charset.forName("UTF-8").newDecoder() + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + output should be (input) } } - class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) From 186b497c945cc7bbe7a21fef56a948dd1fd10774 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 13 Oct 2014 23:31:37 -0700 Subject: [PATCH 286/315] [SPARK-3921] Fix CoarseGrainedExecutorBackend's arguments for Standalone mode The goal of this patch is to fix the swapped arguments in standalone mode, which was caused by https://github.com/apache/spark/commit/79e45c9323455a51f25ed9acd0edd8682b4bbb88#diff-79391110e9f26657e415aa169a004998R153. More details can be found in the JIRA: [SPARK-3921](https://issues.apache.org/jira/browse/SPARK-3921) Tested in Standalone mode, but not in Mesos. Author: Aaron Davidson Closes #2779 from aarondav/fix-standalone and squashes the following commits: 725227a [Aaron Davidson] Fix ExecutorRunnerTest 9d703fe [Aaron Davidson] [SPARK-3921] Fix CoarseGrainedExecutorBackend's arguments for Standalone mode --- .../apache/spark/deploy/worker/ExecutorRunner.scala | 3 ++- .../spark/executor/CoarseGrainedExecutorBackend.scala | 3 +++ .../cluster/SparkDeploySchedulerBackend.scala | 3 ++- .../cluster/mesos/CoarseMesosSchedulerBackend.scala | 8 ++++---- .../spark/deploy/worker/ExecutorRunnerTest.scala | 10 ++++------ 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 71650cd773bcf..71d7385b08eb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -111,13 +111,14 @@ private[spark] class ExecutorRunner( case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString + case "{{APP_ID}}" => appId case other => other } def getCommandSeq = { val command = Command( appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.arguments.map(substituteVariables), appDesc.command.environment, appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 06061edfc0844..c40a3e16675ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -152,6 +152,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { "Usage: CoarseGrainedExecutorBackend " + " [] ") System.exit(1) + + // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) + // and CoarseMesosSchedulerBackend (for mesos mode). case 5 => run(args(0), args(1), args(2), args(3).toInt, args(4), None) case x if x > 5 => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ed209d195ec9d..8c7de75600b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", + "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90828578cd88f..d7f88de4b40aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -150,17 +150,17 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( ("cd %s*; " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d") + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") .format(basename, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores)) + offer.getHostname, numCores, appId)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 39ab53cf0b5b1..5e2592e8d2e8d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -26,14 +26,12 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { - def f(s:String) = new File(s) + val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val appId = "12345-worker321-9876" - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), - f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) - + Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", + new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) } } From 9b6de6fbc00b184d81fc28ac160d03451fad80ec Mon Sep 17 00:00:00 2001 From: Bill Bejeck Date: Tue, 14 Oct 2014 12:12:38 -0700 Subject: [PATCH 287/315] SPARK-3178 setting SPARK_WORKER_MEMORY to a value without a label (m or g) sets the worker memory limit to zero Validate the memory is greater than zero when set from the SPARK_WORKER_MEMORY environment variable or command line without a g or m label. Added unit tests. If memory is 0 an IllegalStateException is thrown. Updated unit tests to mock environment variables by subclassing SparkConf (tip provided by Josh Rosen). Updated WorkerArguments to use SparkConf.getenv instead of System.getenv for reading the SPARK_WORKER_MEMORY environment variable. Author: Bill Bejeck Closes #2309 from bbejeck/spark-memory-worker and squashes the following commits: 51cf915 [Bill Bejeck] SPARK-3178 - Validate the memory is greater than zero when set from the SPARK_WORKER_MEMORY environment variable or command line without a g or m label. Added unit tests. If memory is 0 an IllegalStateException is thrown. Updated unit tests to mock environment variables by subclassing SparkConf (tip provided by Josh Rosen). Updated WorkerArguments to use SparkConf.getenv instead of System.getenv for reading the SPARK_WORKER_MEMORY environment variable. --- .../spark/deploy/worker/WorkerArguments.scala | 13 ++- .../deploy/worker/WorkerArgumentsTest.scala | 82 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1e295aaa48c30..54e3937edde6b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -41,8 +41,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_WORKER_CORES") != null) { cores = System.getenv("SPARK_WORKER_CORES").toInt } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) + if (conf.getenv("SPARK_WORKER_MEMORY") != null) { + memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY")) } if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt @@ -56,6 +56,8 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { parse(args.toList) + checkWorkerMemory() + def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -153,4 +155,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } + + def checkWorkerMemory(): Unit = { + if (memory <= 0) { + val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + throw new IllegalStateException(message) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala new file mode 100644 index 0000000000000..1a28a9a187cd7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.deploy.worker + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + + +class WorkerArgumentsTest extends FunSuite { + + test("Memory can't be set to 0 when cmd line args leave off M or G") { + val conf = new SparkConf + val args = Array("-m", "10000", "spark://localhost:0000 ") + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + + test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "50000" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "5G" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 5120) + } + + test("Memory correctly set from args with M appended to memory value") { + val conf = new SparkConf + val args = Array("-m", "10000M", "spark://localhost:0000 ") + + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 10000) + + } + +} From 7ced88b0d6b4d90c262f19afa99c02b51c0ea5ea Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 14 Oct 2014 14:09:39 -0700 Subject: [PATCH 288/315] [SPARK-3946] gitignore in /python includes wrong directory Modified to ignore not the docs/ directory, but only the docs/_build/ which is the output directory of sphinx build. Author: Masayoshi TSUZUKI Closes #2796 from tsudukim/feature/SPARK-3946 and squashes the following commits: 2bea6a9 [Masayoshi TSUZUKI] [SPARK-3946] gitignore in /python includes wrong directory --- python/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/.gitignore b/python/.gitignore index 80b361ffbd51c..52128cf844a79 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,5 +1,5 @@ *.pyc -docs/ +docs/_build/ pyspark.egg-info build/ dist/ From 24b818b971ba715b6796518e4c6afdecb1b16f15 Mon Sep 17 00:00:00 2001 From: shitis Date: Tue, 14 Oct 2014 14:16:45 -0700 Subject: [PATCH 289/315] [SPARK-3944][Core] Using Option[String] where value of String can be null Author: shitis Closes #2795 from Shiti/master and squashes the following commits: 46897d7 [shitis] Using Option Wrapper to convert String with value null to None --- .../scala/org/apache/spark/util/Utils.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 07477dd460a4b..aad901620f53e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -340,8 +340,8 @@ private[spark] object Utils extends Logging { val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) - uri.getScheme match { - case "http" | "https" | "ftp" => + Option(uri.getScheme) match { + case Some("http") | Some("https") | Some("ftp") => logInfo("Fetching " + url + " to " + tempFile) var uc: URLConnection = null @@ -374,7 +374,7 @@ private[spark] object Utils extends Logging { } } Files.move(tempFile, targetFile) - case "file" | null => + case Some("file") | None => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) @@ -403,7 +403,7 @@ private[spark] object Utils extends Logging { logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) Files.copy(sourceFile, targetFile) } - case _ => + case Some(other) => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) val in = fs.open(new Path(uri)) @@ -1368,16 +1368,17 @@ private[spark] object Utils extends Logging { if (uri.getPath == null) { throw new IllegalArgumentException(s"Given path is malformed: $uri") } - uri.getScheme match { - case windowsDrive(d) if windows => + + Option(uri.getScheme) match { + case Some(windowsDrive(d)) if windows => new URI("file:/" + uri.toString.stripPrefix("/")) - case null => + case None => // Preserve fragments for HDFS file name substitution (denoted by "#") // For instance, in "abc.py#xyz.py", "xyz.py" is the name observed by the application val fragment = uri.getFragment val part = new File(uri.getPath).toURI new URI(part.getScheme, part.getPath, fragment) - case _ => + case Some(other) => uri } } @@ -1399,10 +1400,11 @@ private[spark] object Utils extends Logging { } else { paths.split(",").filter { p => val formattedPath = if (windows) formatWindowsPath(p) else p - new URI(formattedPath).getScheme match { - case windowsDrive(d) if windows => false - case "local" | "file" | null => false - case _ => true + val uri = new URI(formattedPath) + Option(uri.getScheme) match { + case Some(windowsDrive(d)) if windows => false + case Some("local") | Some("file") | None => false + case Some(other) => true } } } From 56096dbaa8cb3ab39bfc2ce5827192313613b010 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Oct 2014 14:42:09 -0700 Subject: [PATCH 290/315] SPARK-3803 [MLLIB] ArrayIndexOutOfBoundsException found in executing computePrincipalComponents Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too Author: Sean Owen Closes #2801 from srowen/SPARK-3803 and squashes the following commits: b4e6d92 [Sean Owen] Avoid overflow in computing n*(n+1)/2 as much as possible; throw explicit error when Gramian computation will fail due to negative array size; warn about large result when computing Gramian too --- .../mllib/linalg/distributed/RowMatrix.scala | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8380058cf9b41..ec2d481dccc22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -111,7 +111,10 @@ class RowMatrix( */ def computeGramianMatrix(): Matrix = { val n = numCols().toInt - val nt: Int = n * (n + 1) / 2 + checkNumColumns(n) + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( @@ -123,6 +126,16 @@ class RowMatrix( RowMatrix.triuToFull(n, GU.data) } + private def checkNumColumns(cols: Int): Unit = { + if (cols > 65535) { + throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") + } + if (cols > 10000) { + val mem = cols * cols * 8 + logWarning(s"$cols columns will require at least $mem bytes of memory!") + } + } + /** * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k @@ -301,12 +314,7 @@ class RowMatrix( */ def computeCovariance(): Matrix = { val n = numCols().toInt - - if (n > 10000) { - val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE - logWarning(s"The number of columns $n is greater than 10000! " + - s"We need at least $mem bytes of memory.") - } + checkNumColumns(n) val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), From 7b4f39f647da1f7b1b57e38827a8639243c661cb Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Tue, 14 Oct 2014 15:09:51 -0700 Subject: [PATCH 291/315] [SPARK-3869] ./bin/spark-class miss Java version with _JAVA_OPTIONS set When _JAVA_OPTIONS environment variable is set, a command "java -version" outputs a message like "Picked up _JAVA_OPTIONS: -Dfile.encoding=UTF-8". ./bin/spark-class knows java version from the first line of "java -version" output, so it mistakes java version with _JAVA_OPTIONS set. Author: cocoatomo Closes #2725 from cocoatomo/issues/3869-mistake-java-version and squashes the following commits: f894ebd [cocoatomo] [SPARK-3869] ./bin/spark-class miss Java version with _JAVA_OPTIONS set --- bin/spark-class | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/spark-class b/bin/spark-class index e8201c18d52de..91d858bc063d0 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then From 66af8e2508bfe9c9d4aecc17a19f297c98e9661d Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 14 Oct 2014 18:50:14 -0700 Subject: [PATCH 292/315] [SPARK-3943] Some scripts bin\*.cmd pollutes environment variables in Windows Modified not to pollute environment variables. Just moved the main logic into `XXX2.cmd` from `XXX.cmd`, and call `XXX2.cmd` with cmd command in `XXX.cmd`. `pyspark.cmd` and `spark-class.cmd` are already using the same way, but `spark-shell.cmd`, `spark-submit.cmd` and `/python/docs/make.bat` are not. Author: Masayoshi TSUZUKI Closes #2797 from tsudukim/feature/SPARK-3943 and squashes the following commits: b397a7d [Masayoshi TSUZUKI] [SPARK-3943] Some scripts bin\*.cmd pollutes environment variables in Windows --- bin/spark-shell.cmd | 5 +- bin/spark-shell2.cmd | 22 ++++ bin/spark-submit.cmd | 51 +-------- bin/spark-submit2.cmd | 68 ++++++++++++ python/docs/make.bat | 242 +---------------------------------------- python/docs/make2.bat | 243 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 342 insertions(+), 289 deletions(-) create mode 100644 bin/spark-shell2.cmd create mode 100644 bin/spark-submit2.cmd create mode 100644 python/docs/make2.bat diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem This is the entry point for running Spark shell. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell +cmd /V /E /C %~dp0spark-shell2.cmd %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd new file mode 100644 index 0000000000000..2ee60b4e2a2b3 --- /dev/null +++ b/bin/spark-shell2.cmd @@ -0,0 +1,22 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +set SPARK_HOME=%~dp0.. + +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index cf6046d1547ad..8f3b84c7b971d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -set SPARK_HOME=%~dp0.. -set ORIG_ARGS=%* - -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) - shift -goto loop -:continue - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. - -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +cmd /V /E /C %~dp0spark-submit2.cmd %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd new file mode 100644 index 0000000000000..cf6046d1547ad --- /dev/null +++ b/bin/spark-submit2.cmd @@ -0,0 +1,68 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + +set SPARK_HOME=%~dp0.. +set ORIG_ARGS=%* + +rem Reset the values of all variables used +set SPARK_SUBMIT_DEPLOY_MODE=client +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf +set SPARK_SUBMIT_DRIVER_MEMORY= +set SPARK_SUBMIT_LIBRARY_PATH= +set SPARK_SUBMIT_CLASSPATH= +set SPARK_SUBMIT_OPTS= +set SPARK_SUBMIT_BOOTSTRAP_DRIVER= + +:loop +if [%1] == [] goto continue + if [%1] == [--deploy-mode] ( + set SPARK_SUBMIT_DEPLOY_MODE=%2 + ) else if [%1] == [--properties-file] ( + set SPARK_SUBMIT_PROPERTIES_FILE=%2 + ) else if [%1] == [--driver-memory] ( + set SPARK_SUBMIT_DRIVER_MEMORY=%2 + ) else if [%1] == [--driver-library-path] ( + set SPARK_SUBMIT_LIBRARY_PATH=%2 + ) else if [%1] == [--driver-class-path] ( + set SPARK_SUBMIT_CLASSPATH=%2 + ) else if [%1] == [--driver-java-options] ( + set SPARK_SUBMIT_OPTS=%2 + ) + shift +goto loop +:continue + +rem For client mode, the driver will be launched in the same JVM that launches +rem SparkSubmit, so we may need to read the properties file for any extra class +rem paths, library paths, java options and memory early on. Otherwise, it will +rem be too late by the time the driver JVM has started. + +if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( + if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( + rem Parse the properties file only if the special configs exist + for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ + %SPARK_SUBMIT_PROPERTIES_FILE%') do ( + set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + ) + ) +) + +cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% diff --git a/python/docs/make.bat b/python/docs/make.bat index adad44fd7536a..c011e82b4a35a 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,242 +1,6 @@ @ECHO OFF -REM Command file for Sphinx documentation +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat new file mode 100644 index 0000000000000..7bcaeafad13d7 --- /dev/null +++ b/python/docs/make2.bat @@ -0,0 +1,243 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end From 18ab6bd709bb9fcae290ffc43294d13f06670d55 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 14 Oct 2014 21:37:51 -0700 Subject: [PATCH 293/315] SPARK-1307 [DOCS] Don't use term 'standalone' to refer to a Spark Application HT to Diana, just proposing an implementation of her suggestion, which I rather agreed with. Is there a second/third for the motion? Refer to "self-contained" rather than "standalone" apps to avoid confusion with standalone deployment mode. And fix placement of reference to this in MLlib docs. Author: Sean Owen Closes #2787 from srowen/SPARK-1307 and squashes the following commits: b5b82e2 [Sean Owen] Refer to "self-contained" rather than "standalone" apps to avoid confusion with standalone deployment mode. And fix placement of reference to this in MLlib docs. --- docs/mllib-clustering.md | 14 +++++++------- docs/mllib-collaborative-filtering.md | 14 +++++++------- docs/mllib-dimensionality-reduction.md | 17 +++++++++-------- docs/mllib-linear-methods.md | 20 ++++++++++---------- docs/quick-start.md | 8 ++++---- 5 files changed, 37 insertions(+), 36 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d10bd63746629..7978e934fb36b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: {% highlight java %} @@ -113,12 +113,6 @@ public class KMeansExample { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -153,3 +147,9 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index d5c539db791be..2094963392295 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -184,12 +184,6 @@ public class CollaborativeFiltering { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -229,6 +223,12 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 21cb35b4270ca..870fed6cc5024 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -121,9 +121,9 @@ public class SVD { The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -200,10 +200,11 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index d31bec3e1bd01..bc914a1899801 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -323,9 +323,9 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -482,12 +482,6 @@ public class LinearRegression { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -519,6 +513,12 @@ print("Mean Squared Error = " + str(MSE))
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, diff --git a/docs/quick-start.md b/docs/quick-start.md index 23313d8aa6152..6236de0e1f2c4 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -8,7 +8,7 @@ title: Quick Start This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), -then show how to write standalone applications in Java, Scala, and Python. +then show how to write applications in Java, Scala, and Python. See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the @@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia -# Standalone Applications -Now say we wanted to write a standalone application using the Spark API. We will walk through a +# Self-Contained Applications +Now say we wanted to write a self-contained application using the Spark API. We will walk through a simple application in both Scala (with SBT), Java (with Maven), and Python.
    @@ -387,7 +387,7 @@ Lines with a: 46, Lines with b: 23
    -Now we will show how to write a standalone application using the Python API (PySpark). +Now we will show how to write an application using the Python API (PySpark). As an example, we'll create a simple Spark application, `SimpleApp.py`: From 293a0b5dbba0474832dc7e9d387f3b10f6c452ea Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Tue, 14 Oct 2014 22:16:38 -0700 Subject: [PATCH 294/315] [SPARK-2098] All Spark processes should support spark-defaults.conf, config file This is another implementation about #1256 cc andrewor14 vanzin Author: GuoQiang Li Closes #2379 from witgo/SPARK-2098-new and squashes the following commits: 4ef1cbd [GuoQiang Li] review commit 49ef70e [GuoQiang Li] Refactor getDefaultPropertiesFile c45d20c [GuoQiang Li] All Spark processes should support spark-defaults.conf, config file --- .../spark/deploy/SparkSubmitArguments.scala | 42 ++-------------- .../SparkSubmitDriverBootstrapper.scala | 2 +- .../history/HistoryServerArguments.scala | 16 ++++++- .../spark/deploy/master/MasterArguments.scala | 19 ++++++-- .../spark/deploy/worker/WorkerArguments.scala | 21 ++++++-- .../scala/org/apache/spark/util/Utils.scala | 48 +++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 19 ++++++++ docs/monitoring.md | 7 +++ 8 files changed, 124 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57b251ff47714..72a452e0aefb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,14 +17,11 @@ package org.apache.spark.deploy -import java.io.{File, FileInputStream, IOException} -import java.util.Properties import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.SparkException import org.apache.spark.util.Utils /** @@ -63,9 +60,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St val defaultProperties = new HashMap[String, String]() if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - val file = new File(filename) - SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) => - if (k.startsWith("spark")) { + Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + if (k.startsWith("spark.")) { defaultProperties(k) = v if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } else { @@ -90,19 +86,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St */ private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user - if (propertiesFile == null) { - val sep = File.separator - val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") - val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - - confDir.foreach { sparkConfDir => - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } - } - } + propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) val properties = HashMap[String, String]() properties.putAll(defaultSparkProperties) @@ -397,23 +381,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St SparkSubmit.exitFn() } } - -object SparkSubmitArguments { - /** Load properties present in the given file. */ - def getPropertiesFromFile(file: File): Seq[(String, String)] = { - require(file.exists(), s"Properties file $file does not exist") - require(file.isFile(), s"Properties file $file is not a normal file") - val inputStream = new FileInputStream(file) - try { - val properties = new Properties() - properties.load(inputStream) - properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim)) - } catch { - case e: IOException => - val message = s"Failed when loading Spark properties file $file" - throw new SparkException(message, e) - } finally { - inputStream.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index a64170a47bc1c..0125330589da5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper { assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") // Parse the properties file for the equivalent spark.driver.* configs - val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val properties = Utils.getPropertiesFromFile(propertiesFile) val confDriverMemory = properties.get("spark.driver.memory") val confLibraryPath = properties.get("spark.driver.extraLibraryPath") val confClasspath = properties.get("spark.driver.extraClassPath") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 25fc76c23e0fb..5bce32a04d16d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,12 +18,14 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf +import org.apache.spark.util.Utils /** * Command-line parser for the master. */ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { private var logDir: String = null + private var propertiesFile: String = null parse(args.toList) @@ -32,11 +34,16 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case ("--dir" | "-d") :: value :: tail => logDir = value conf.set("spark.history.fs.logDirectory", value) + System.setProperty("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => printUsageAndExit(0) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + case Nil => case _ => @@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] } } + // This mutates the SparkConf, so all accesses to it must be made after this line + Utils.loadDefaultSparkProperties(conf, propertiesFile) + private def printUsageAndExit(exitCode: Int) { System.err.println( """ - |Usage: HistoryServer + |Usage: HistoryServer [options] + | + |Options: + | --properties-file FILE Path to a custom Spark properties file. + | Default is conf/spark-defaults.conf. | |Configuration options can be set by setting the corresponding JVM system property. |History Server options are always available; additional options depend on the provider. diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 4b0dbbe543d3f..e34bee7854292 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_MASTER_HOST") != null) { @@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } + + parse(args.toList) + + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + if (conf.contains("spark.master.ui.port")) { webUiPort = conf.get("spark.master.ui.port").toInt } - parse(args.toList) - def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case Nil => {} @@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") + " --webui-port PORT Port for web UI (default: 8080)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 54e3937edde6b..019cd70f2a229 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var memory = inferDefaultMemory() var masters: Array[String] = null var workDir: String = null + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { @@ -47,15 +48,19 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } - if (conf.contains("spark.worker.ui.port")) { - webUiPort = conf.get("spark.worker.ui.port").toInt - } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } parse(args.toList) + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } + checkWorkerMemory() def parse(args: List[String]): Unit = args match { @@ -89,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case value :: tail => @@ -124,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") + " --webui-port PORT Port for web UI (default: 8081)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index aad901620f53e..cbc4095065a19 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1410,6 +1410,54 @@ private[spark] object Utils extends Logging { } } + /** + * Load default Spark properties from the given file. If no file is provided, + * use the common defaults file. This mutates state in the given SparkConf and + * in this JVM's system properties if the config specified in the file is not + * already set. Return the path of the properties file used. + */ + def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = { + val path = Option(filePath).getOrElse(getDefaultPropertiesFile()) + Option(path).foreach { confFile => + getPropertiesFromFile(confFile).filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.setIfMissing(k, v) + sys.props.getOrElseUpdate(k, v) + } + } + path + } + + /** Load properties present in the given file. */ + def getPropertiesFromFile(filename: String): Map[String, String] = { + val file = new File(filename) + require(file.exists(), s"Properties file $file does not exist") + require(file.isFile(), s"Properties file $file is not a normal file") + + val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + try { + val properties = new Properties() + properties.load(inReader) + properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + } catch { + case e: IOException => + throw new SparkException(s"Failed when loading Spark properties from $filename", e) + } finally { + inReader.close() + } + } + + /** Return the path of the default Spark properties file. */ + def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = { + env.get("SPARK_CONF_DIR") + .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" }) + .map { t => new File(s"$t${File.separator}spark-defaults.conf")} + .filter(_.isFile) + .map(_.getAbsolutePath) + .orNull + } + /** Return a nice string representation of the exception, including the stack trace. */ def exceptionString(e: Exception): String = { if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 0344da60dae66..ea7ef0524d1e1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -27,6 +27,8 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UtilsSuite extends FunSuite { test("bytesToString") { @@ -332,4 +334,21 @@ class UtilsSuite extends FunSuite { assert(!tempFile2.exists()) } + test("loading properties from file") { + val outFile = File.createTempFile("test-load-spark-properties", "test") + try { + System.setProperty("spark.test.fileNameLoadB", "2") + Files.write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8) + val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) + properties + .filter { case (k, v) => k.startsWith("spark.")} + .foreach { case (k, v) => sys.props.getOrElseUpdate(k, v)} + val sparkConf = new SparkConf + assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) + assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) + } finally { + outFile.delete() + } + } } diff --git a/docs/monitoring.md b/docs/monitoring.md index d07ec4a57a2cc..e3f81a76acdbb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -77,6 +77,13 @@ follows: one implementation, provided by Spark, which looks for application logs stored in the file system. +
    + + + + From 044583a241203e7fe759366b273ad32fd9bf7c05 Mon Sep 17 00:00:00 2001 From: prudhvi Date: Thu, 16 Oct 2014 02:05:44 -0400 Subject: [PATCH 295/315] [Core] Upgrading ScalaStyle version to 0.5 and removing SparkSpaceAfterCommentStartChecker. Author: prudhvi Closes #2799 from prudhvije/ScalaStyle/space-after-comment-start and squashes the following commits: fc263a1 [prudhvi] [Core] Using scalastyle to check the space after comment start --- project/plugins.sbt | 2 +- .../SparkSpaceAfterCommentStartChecker.scala | 58 ------------------- scalastyle-config.xml | 2 +- 3 files changed, 2 insertions(+), 60 deletions(-) delete mode 100644 project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala diff --git a/project/plugins.sbt b/project/plugins.sbt index 8096c61414660..678f5ed1ba610 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -17,7 +17,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala deleted file mode 100644 index 80d3faa3fe749..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} -import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token} -import scalariform.parser.CompilationUnit - -class SparkSpaceAfterCommentStartChecker extends ScalariformChecker { - val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end" - - private def multiLineCommentRegex(comment: Token) = - Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def scalaDocPatternRegex(comment: Token) = - Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def singleLineCommentRegex(comment: Token): Boolean = - comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""") - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens - .filter(hasComment) - .map { - _.associatedWhitespaceAndComments.comments.map { - case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset) - case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset) - case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset) - case _ => None - }.flatten - }.flatten.map(PositionError(_)) - } - - - private def hasComment(x: Token) = - x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty - -} diff --git a/scalastyle-config.xml b/scalastyle-config.xml index c54f8b72ebf42..0ff521706c71a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -141,5 +141,5 @@ - + From 4c589cac4496c6a4bb8485a340bd0641dca13847 Mon Sep 17 00:00:00 2001 From: Shiti Date: Thu, 16 Oct 2014 10:52:06 -0700 Subject: [PATCH 296/315] [SPARK-3944][Core] Code re-factored as suggested Author: Shiti Closes #2810 from Shiti/master and squashes the following commits: 051d82f [Shiti] setting the default value of uri scheme to "file" where matching "file" or None yields the same result --- .../main/scala/org/apache/spark/util/Utils.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index cbc4095065a19..53a7512edd852 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -340,8 +340,8 @@ private[spark] object Utils extends Logging { val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) - Option(uri.getScheme) match { - case Some("http") | Some("https") | Some("ftp") => + Option(uri.getScheme).getOrElse("file") match { + case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) var uc: URLConnection = null @@ -374,7 +374,7 @@ private[spark] object Utils extends Logging { } } Files.move(tempFile, targetFile) - case Some("file") | None => + case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) @@ -403,7 +403,7 @@ private[spark] object Utils extends Logging { logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) Files.copy(sourceFile, targetFile) } - case Some(other) => + case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val fs = getHadoopFileSystem(uri, hadoopConf) val in = fs.open(new Path(uri)) @@ -1401,10 +1401,10 @@ private[spark] object Utils extends Logging { paths.split(",").filter { p => val formattedPath = if (windows) formatWindowsPath(p) else p val uri = new URI(formattedPath) - Option(uri.getScheme) match { - case Some(windowsDrive(d)) if windows => false - case Some("local") | Some("file") | None => false - case Some(other) => true + Option(uri.getScheme).getOrElse("file") match { + case windowsDrive(d) if windows => false + case "local" | "file" => false + case _ => true } } } From 091d32c52e9d73da95896016c1d920e89858abfa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 16 Oct 2014 14:56:50 -0700 Subject: [PATCH 297/315] [SPARK-3971] [MLLib] [PySpark] hotfix: Customized pickler should work in cluster mode Customized pickler should be registered before unpickling, but in executor, there is no way to register the picklers before run the tasks. So, we need to register the picklers in the tasks itself, duplicate the javaToPython() and pythonToJava() in MLlib, call SerDe.initialize() before pickling or unpickling. Author: Davies Liu Closes #2830 from davies/fix_pickle and squashes the following commits: 0c85fb9 [Davies Liu] revert the privacy change 6b94e15 [Davies Liu] use JavaConverters instead of JavaConversions 0f02050 [Davies Liu] hotfix: Customized pickler does not work in cluster --- .../apache/spark/api/python/PythonRDD.scala | 7 ++- .../apache/spark/api/python/SerDeUtil.scala | 14 ++++- .../mllib/api/python/PythonMLLibAPI.scala | 52 +++++++++++++++++-- python/pyspark/context.py | 2 - python/pyspark/mllib/classification.py | 4 +- python/pyspark/mllib/clustering.py | 4 +- python/pyspark/mllib/feature.py | 5 +- python/pyspark/mllib/linalg.py | 13 +++++ python/pyspark/mllib/random.py | 2 +- python/pyspark/mllib/recommendation.py | 7 +-- python/pyspark/mllib/regression.py | 4 +- python/pyspark/mllib/stat.py | 7 +-- python/pyspark/mllib/tree.py | 8 +-- python/pyspark/mllib/util.py | 6 +-- 14 files changed, 101 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4acbdf9d5e25f..29ca751519abd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials @@ -746,6 +747,7 @@ private[spark] object PythonRDD extends Logging { def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler + SerDeUtil.initialize() iter.flatMap { row => unpickle.loads(row) match { // in case of objects are pickled in batch mode @@ -785,7 +787,7 @@ private[spark] object PythonRDD extends Logging { }.toJavaRDD() } - private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { private val pickle = new Pickler() private var batch = 1 private val buffer = new mutable.ArrayBuffer[Any] @@ -822,11 +824,12 @@ private[spark] object PythonRDD extends Logging { */ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { pyRDD.rdd.mapPartitions { iter => + SerDeUtil.initialize() val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]] + obj.asInstanceOf[JArrayList[_]].asScala } else { Seq(obj) } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 7903457b17e13..ebdc3533e0992 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ -private[python] object SerDeUtil extends Logging { +private[spark] object SerDeUtil extends Logging { // Unpickle array.array generated by Python 2.6 class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { // /* Description of types */ @@ -76,9 +76,18 @@ private[python] object SerDeUtil extends Logging { } } + private var initialized = false + // This should be called before trying to unpickle array.array from Python + // In cluster mode, this should be put in closure def initialize() = { - Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + synchronized{ + if (!initialized) { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + initialized = true + } + } } + initialize() private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler @@ -143,6 +152,7 @@ private[python] object SerDeUtil extends Logging { obj.asInstanceOf[Array[_]].length == 2 } pyRDD.mapPartitions { iter => + initialize() val unpickle = new Unpickler val unpickled = if (batchSerialized) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f7251e65e04f1..9a100170b75c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.util.{ArrayList => JArrayList} import scala.collection.JavaConverters._ import scala.language.existentials @@ -27,6 +28,7 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.feature.Word2Vec @@ -639,13 +641,24 @@ private[spark] object SerDe extends Serializable { } } + var initialized = false + // This should be called before trying to serialize any above classes + // In cluster mode, this should be put in the closure def initialize(): Unit = { - new DenseVectorPickler().register() - new DenseMatrixPickler().register() - new SparseVectorPickler().register() - new LabeledPointPickler().register() - new RatingPickler().register() + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseVectorPickler().register() + new LabeledPointPickler().register() + new RatingPickler().register() + initialized = true + } + } } + // will not called in Executor automatically + initialize() def dumps(obj: AnyRef): Array[Byte] = { new Pickler().dumps(obj) @@ -659,4 +672,33 @@ private[spark] object SerDe extends Serializable { def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + new PythonRDD.AutoBatchedPickler(iter) + } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } } diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 89d2e2e5b4a8e..8d27ccb95f82c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -215,8 +215,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile - SparkContext._jvm.SerDeUtil.initialize() - SparkContext._jvm.SerDe.initialize() if instance: if (SparkContext._active_spark_context and diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cd43982191702..e295c9d0954d9 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,7 +21,7 @@ from numpy import array from pyspark import SparkContext, PickleSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -244,7 +244,7 @@ def train(cls, data, lambda_=1.0): :param lambda_: The smoothing parameter """ sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) + jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_) labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12c56022717a5..5ee7997104d21 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,7 +17,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['KMeansModel', 'KMeans'] @@ -85,7 +85,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" # cache serialized data to avoid objects over head in JVM cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) + _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) centers = ser.loads(str(bytes)) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f4cbf31b94fe2..b5a3f22c6907e 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -19,8 +19,7 @@ Python package for feature in MLlib. """ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd __all__ = ['Word2Vec', 'Word2VecModel'] @@ -176,7 +175,7 @@ def fit(self, data): seed = self.seed model = sc._jvm.PythonMLLibAPI().trainWord2Vec( - data._to_java_object_rdd(), vectorSize, + _to_java_object_rdd(data), vectorSize, learningRate, numPartitions, numIterations, seed) return Word2VecModel(sc, model) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 24c5480b2f753..773d8d393805d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,8 @@ import numpy as np +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -50,6 +52,17 @@ def fast_pickle_array(ar): _have_scipy = False +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + def _convert_to_vector(l): if isinstance(l, Vector): return l diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index a787e4dea2c55..73baba4ace5f6 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -32,7 +32,7 @@ def serialize(f): @wraps(f) def func(sc, *a, **kw): jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc, + return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc, BatchedSerializer(PickleSerializer(), 1024)) return func diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 59c1c5ff0ced0..17f96b8700bd7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,6 +18,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD +from pyspark.mllib.linalg import _to_java_object_rdd __all__ = ['MatrixFactorizationModel', 'ALS'] @@ -77,9 +78,9 @@ def predictAll(self, user_product): first = tuple(map(int, first)) assert all(type(x) is int for x in first), "user and product in user_product shoul be int" sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd()) + tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd()) jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc, + return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, AutoBatchedSerializer(PickleSerializer())) @@ -97,7 +98,7 @@ def _prepare(cls, ratings): # serialize them by AutoBatchedSerializer before cache to reduce the # objects overhead in JVM cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return cached._to_java_object_rdd() + return _to_java_object_rdd(cached) @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 12b322aaae796..93e17faf5cd51 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,8 +19,8 @@ from numpy import array from pyspark import SparkContext -from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -131,7 +131,7 @@ def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights) # use AutoBatchedSerializer before cache to reduce the memory # overhead in JVM cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(cached._to_java_object_rdd(), initial_bytes) + ans = train_func(_to_java_object_rdd(cached), initial_bytes) assert len(ans) == 2, "JVM call result had unexpected length" weights = ser.loads(str(ans[0])) return modelClass(weights, ans[1]) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index b9de0909a6fb1..a6019dadf781c 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ from functools import wraps from pyspark import PickleSerializer +from pyspark.mllib.linalg import _to_java_object_rdd __all__ = ['MultivariateStatisticalSummary', 'Statistics'] @@ -106,7 +107,7 @@ def colStats(rdd): array([ 2., 0., 0., -2.]) """ sc = rdd.ctx - jrdd = rdd._to_java_object_rdd() + jrdd = _to_java_object_rdd(rdd) cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) return MultivariateStatisticalSummary(sc, cStats) @@ -162,14 +163,14 @@ def corr(x, y=None, method=None): if type(y) == str: raise TypeError("Use 'method=' to specify method name.") - jx = x._to_java_object_rdd() + jx = _to_java_object_rdd(x) if not y: resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) bytes = sc._jvm.SerDe.dumps(resultMat) ser = PickleSerializer() return ser.loads(str(bytes)).toArray() else: - jy = y._to_java_object_rdd() + jy = _to_java_object_rdd(y) return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d7abfb96b7fe..0938eebd3a548 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ from pyspark import SparkContext, RDD from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector +from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] @@ -61,8 +61,8 @@ def predict(self, x): return self._sc.parallelize([]) if not isinstance(first[0], Vector): x = x.map(_convert_to_vector) - jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD() - jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred) + jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() + jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) else: @@ -104,7 +104,7 @@ def _train(data, type, numClasses, categoricalFeaturesInfo, first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" sc = data.context - jrdd = data._to_java_object_rdd() + jrdd = _to_java_object_rdd(data) cfiMap = MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 1357fd4fbc8aa..84b39a48619d2 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -19,7 +19,7 @@ import warnings from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -174,8 +174,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd) - return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer())) + jpyrdd = sc._jvm.SerDe.javaToPython(jrdd) + return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer())) def _test(): From 99e416b6d64402a5432a265797a1c155a38f4e6f Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 16 Oct 2014 16:15:55 -0700 Subject: [PATCH 298/315] [SQL] Fixes the race condition that may cause test failure The removed `Future` was used to end the test case as soon as the Spark SQL CLI process exits. When the process exits prematurely, this mechanism prevents the test case to wait until timeout. But it also creates a race condition: when `foundAllExpectedAnswers.tryFailure` is called, there are chances that the last expected output line of the CLI process hasn't been caught by the main logics of the test code, thus fails the test case. Removing this `Future` doesn't affect correctness. Author: Cheng Lian Closes #2823 from liancheng/clean-clisuite and squashes the following commits: 489a97c [Cheng Lian] Fixes the race condition that may cause test failure --- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index fc97a25be34be..8a72e9d2aef57 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -78,12 +78,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val process = (Process(command) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - Future { - val exitValue = process.exitValue() - foundAllExpectedAnswers.tryFailure( - new SparkException(s"Spark SQL CLI process exit value: $exitValue")) - } - try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => From 2fe0ba95616bb3860736b6b426635a5d2a0e9bd9 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 16 Oct 2014 21:38:45 -0400 Subject: [PATCH 299/315] SPARK-3874: Provide stable TaskContext API This is a small number of clean-up changes on top of #2782. Closes #2782. Author: Prashant Sharma Author: Patrick Wendell Closes #2803 from pwendell/pr-2782 and squashes the following commits: 56d5b7a [Patrick Wendell] Minor clean-up 44089ec [Patrick Wendell] Clean-up the TaskContext API. ed551ce [Prashant Sharma] Fixed a typo df261d0 [Prashant Sharma] Josh's suggestion facf3b1 [Prashant Sharma] Fixed the mima issue. 7ecc2fe [Prashant Sharma] CR, Moved implementations to TaskContextImpl bbd9e05 [Prashant Sharma] adding missed out files to git. ef633f5 [Prashant Sharma] SPARK-3874, Provide stable TaskContext API --- .../java/org/apache/spark/TaskContext.java | 225 +++--------------- .../org/apache/spark/TaskContextHelper.scala | 29 +++ .../org/apache/spark/TaskContextImpl.scala | 91 +++++++ .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 8 +- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../org/apache/spark/scheduler/Task.scala | 10 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../util/JavaTaskCompletionListenerImpl.java | 4 +- .../org/apache/spark/CacheManagerSuite.scala | 8 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 6 +- .../sql/parquet/ParquetTableOperations.scala | 4 +- 16 files changed, 186 insertions(+), 223 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/TaskContextHelper.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskContextImpl.scala diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 4e6d708af0ea7..2d998d4c7a5d9 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -18,131 +18,55 @@ package org.apache.spark; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import scala.Function0; import scala.Function1; import scala.Unit; -import scala.collection.JavaConversions; import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.util.TaskCompletionListener; -import org.apache.spark.util.TaskCompletionListenerException; /** -* :: DeveloperApi :: -* Contextual information about a task which can be read or mutated during execution. -*/ -@DeveloperApi -public class TaskContext implements Serializable { - - private int stageId; - private int partitionId; - private long attemptId; - private boolean runningLocally; - private TaskMetrics taskMetrics; - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - } - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task use + * TaskContext.get(). + */ +public abstract class TaskContext implements Serializable { /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); + public static TaskContext get() { + return taskContext.get(); } private static ThreadLocal taskContext = new ThreadLocal(); - /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ - public static void setTaskContext(TaskContext tc) { + static void setTaskContext(TaskContext tc) { taskContext.set(tc); } - public static TaskContext get() { - return taskContext.get(); - } - - /** :: Internal API :: */ - public static void unset() { + static void unset() { taskContext.remove(); } - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile boolean interrupted = false; - - // Whether the task has completed. - private volatile boolean completed = false; - /** - * Checks whether the task has completed. + * Whether the task has completed. */ - public boolean isCompleted() { - return completed; - } + public abstract boolean isCompleted(); /** - * Checks whether the task has been killed. + * Whether the task has been killed. */ - public boolean isInterrupted() { - return interrupted; - } + public abstract boolean isInterrupted(); + + /** @deprecated: use isRunningLocally() */ + @Deprecated + public abstract boolean runningLocally(); + + public abstract boolean isRunningLocally(); /** * Add a (Java friendly) listener to be executed on task completion. @@ -150,10 +74,7 @@ public boolean isInterrupted() { *

    * An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { - onCompleteCallbacks.add(listener); - return this; - } + public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); /** * Add a listener in the form of a Scala closure to be executed on task completion. @@ -161,109 +82,27 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { *

    * An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } + public abstract TaskContext addTaskCompletionListener(final Function1 f); /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * Deprecated: use addTaskCompletionListener - * + * @deprecated: use addTaskCompletionListener + * * @param f Callback function. */ @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } - - /** - * ::Internal API:: - * Marks the task as completed and triggers the listeners. - */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl: revlist) { - try { - tcl.onTaskCompletion(this); - } catch (Throwable e) { - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** - * ::Internal API:: - * Marks the task for interruption, i.e. cancellation. - */ - public void markInterrupted() { - interrupted = true; - } - - @Deprecated - /** Deprecated: use getStageId() */ - public int stageId() { - return stageId; - } - - @Deprecated - /** Deprecated: use getPartitionId() */ - public int partitionId() { - return partitionId; - } - - @Deprecated - /** Deprecated: use getAttemptId() */ - public long attemptId() { - return attemptId; - } - - @Deprecated - /** Deprecated: use isRunningLocally() */ - public boolean runningLocally() { - return runningLocally; - } - - public boolean isRunningLocally() { - return runningLocally; - } + public abstract void addOnCompleteCallback(final Function0 f); - public int getStageId() { - return stageId; - } + public abstract int stageId(); - public int getPartitionId() { - return partitionId; - } + public abstract int partitionId(); - public long getAttemptId() { - return attemptId; - } + public abstract long attemptId(); - /** ::Internal API:: */ - public TaskMetrics taskMetrics() { - return taskMetrics; - } + /** ::DeveloperApi:: */ + @DeveloperApi + public abstract TaskMetrics taskMetrics(); } diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala new file mode 100644 index 0000000000000..4636c4600a01a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This class exists to restrict the visibility of TaskContext setters. + */ +private [spark] object TaskContextHelper { + + def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + + def unset(): Unit = TaskContext.unset() + +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala new file mode 100644 index 0000000000000..afd2b85d33a77 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} + +import scala.collection.mutable.ArrayBuffer + +private[spark] class TaskContextImpl(val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext + with Logging { + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + + // Whether the corresponding task has been killed. + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } + + @deprecated("use addTaskCompletionListener", "1.1.0") + override def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } + } + + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { + completed = true + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true + } + + override def isCompleted: Boolean = completed + + override def isRunningLocally: Boolean = runningLocally + + override def isInterrupted: Boolean = interrupted +} + diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6b63eb23e9ee1..8010dd90082f8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -196,7 +196,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0d97506450a7f..929ded58a3bd5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outfmt.newInstance @@ -1027,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.getStageId, context.getPartitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { var count = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 788eb1ff4e455..f81fa6d8089fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -633,14 +633,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, true) - TaskContext.setTaskContext(taskContext) + new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c6e47c84a0cb2..2552d03d18d06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, false) - TaskContext.setTaskContext(context) + context = new TaskContextImpl(stageId, partitionId, attemptId, false) + TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } @@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex var metrics: Option[TaskMetrics] = None // Task context, to be initialized in run(). - @transient protected var context: TaskContext = _ + @transient protected var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4a078435447e5..b8fa822ae4bd8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 0944bf8cd5c71..e9ec700e32e15 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.getStageId(); - context.getPartitionId(); + context.stageId(); + context.partitionId(); context.isRunningLocally(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d735010d7c9d5..c0735f448d193 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index be972c5e97a7e..271a90c6646bb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index faba5508c906c..561a5e9cd90c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd70929656..a8c049d749015 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.mockito.Mockito._ @@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { (bmId, Seq((blId1, 1L), (blId2, 1L)))) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 39f8ba4745737..d919b18e09855 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -32,7 +32,7 @@ object MimaBuild { ProblemFilters.exclude[MissingMethodProblem](fullName), // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated - // bytecode. It is not possible to exhustively list everything. + // bytecode. It is not possible to exhaustively list everything. // But this should be okay. ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d499302124461..350aad47735e4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,11 @@ object MimaExcludes { "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus") + "org.apache.spark.scheduler.MapStatus"), + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext") + ) case v if v.startsWith("1.1") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1f4237d7ede65..5c6fa78ae3895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) From 7f7b50ed9d4ffdd6b23e0faa56b068a049da67f7 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 16 Oct 2014 18:58:18 -0700 Subject: [PATCH 300/315] [SPARK-3923] Increase Akka heartbeat pause above heartbeat interval Something about the 2.3.4 upgrade seems to have made the issue manifest where all the services disconnect from each other after exactly 1000 seconds (which is the heartbeat interval). [This post](https://groups.google.com/forum/#!topic/akka-user/X3xzpTCbEFs) suggests that heartbeat pause should be greater than heartbeat interval, and increasing the pause from 600s to 6000s seems to have rectified the issue. My current cluster has now exceeded 1400s of uptime without failure! I do not know why this fixed it, because the threshold we have set for the failure detector is the exponent of a timeout, and 300 is extremely large. Perhaps the default failure detector changed in 2.3.4 and now ignores threshold. Author: Aaron Davidson Closes #2784 from aarondav/fix-timeout and squashes the following commits: bd1151a [Aaron Davidson] Increase pause, don't decrease interval 9cb0372 [Aaron Davidson] [SPARK-3923] Decrease Akka heartbeat interval below heartbeat pause --- core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 2 +- docs/configuration.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index e2d32c859bbda..f41c8d0315cb3 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) val akkaFailureDetector = conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) diff --git a/docs/configuration.md b/docs/configuration.md index f311f0d2a6206..8515ee045177f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -725,7 +725,7 @@ Apart from these, the following properties are also available, and may be useful

    - +
    StructType org.apache.spark.sql.Row + StructType(fields)
    Note: fields is a Seq of StructFields. Also, two fields with the same name are not allowed. @@ -1267,7 +1289,7 @@ import org.apache.spark.sql._ All data types of Spark SQL are located in the package of `org.apache.spark.sql.api.java`. To access or create a data type, -please use factory methods provided in +please use factory methods provided in `org.apache.spark.sql.api.java.DataType`. @@ -1373,7 +1395,7 @@ please use factory methods provided in - - - + From 7d1a37239c50394025d9f16acf5dcd05cfbe7250 Mon Sep 17 00:00:00 2001 From: chesterxgchen Date: Wed, 17 Sep 2014 10:25:52 -0500 Subject: [PATCH 017/315] SPARK-3177 (on Master Branch) The JIRA and PR was original created for branch-1.1, and move to master branch now. Chester The Issue is due to that yarn-alpha and yarn have different APIs for certain class fields. In this particular case, the ClientBase using reflection to to address this issue, and we need to different way to test the ClientBase's method. Original ClientBaseSuite using getFieldValue() method to do this. But it doesn't work for yarn-alpha as the API returns an array of String instead of just String (which is the case for Yarn-stable API). To fix the test, I add a new method def getFieldValue2[A: ClassTag, A1: ClassTag, B](clazz: Class[_], field: String, defaults: => B) (mapTo: A => B)(mapTo1: A1 => B) : B = Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) case _ => defaults }.toOption.getOrElse(defaults) to handle the cases where the field type can be either type A or A1. In this new method the type A or A1 is pattern matched and corresponding mapTo function (mapTo or mapTo1) is used. Author: chesterxgchen Closes #2204 from chesterxgchen/SPARK-3177-master and squashes the following commits: e72a6ea [chesterxgchen] The Issue is due to that yarn-alpha and yarn have different APIs for certain class fields. In this particular case, the ClientBase using reflection to to address this issue, and we need to different way to test the ClientBase's method. Original ClientBaseSuite using getFieldValue() method to do this. But it doesn't work for yarn-alpha as the API returns an array of String instead of just String (which is the case for Yarn-stable API). --- .../spark/deploy/yarn/ClientBaseSuite.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 5480eca7c832c..c3b7a2c8f02e5 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -38,6 +38,7 @@ import org.scalatest.Matchers import scala.collection.JavaConversions._ import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.reflect.ClassTag import scala.util.Try import org.apache.spark.{SparkException, SparkConf} @@ -200,9 +201,10 @@ class ClientBaseSuite extends FunSuite with Matchers { val knownDefMRAppCP: Seq[String] = - getFieldValue[String, Seq[String]](classOf[MRJobConfig], - "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH", - Seq[String]())(a => a.split(",")) + getFieldValue2[String, Array[String], Seq[String]]( + classOf[MRJobConfig], + "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH", + Seq[String]())(a => a.split(","))(a => a.toSeq) val knownYARNAppCP = Some(Seq("/known/yarn/path")) @@ -232,6 +234,17 @@ class ClientBaseSuite extends FunSuite with Matchers { def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + def getFieldValue2[A: ClassTag, A1: ClassTag, B]( + clazz: Class[_], + field: String, + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B) : B = { + Try(clazz.getField(field)).map(_.get(null)).map { + case v: A => mapTo(v) + case v1: A1 => mapTo1(v1) + case _ => defaults + }.toOption.getOrElse(defaults) + } + private class DummyClient( val args: ClientArguments, val conf: Configuration, From 8fbd5f4a90f92e064aa057adbd3f8c58dd0087fa Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 17 Sep 2014 12:33:09 -0700 Subject: [PATCH 018/315] [Docs] minor grammar fix Author: Nicholas Chammas Closes #2430 from nchammas/patch-2 and squashes the following commits: d476bfb [Nicholas Chammas] [Docs] minor grammar fix --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c6b4aa5344757..b6c6b050fa331 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,5 +8,5 @@ submitting any copyrighted material via pull request, email, or other means you agree to license the material under the project's open source license and warrant that you have the legal authority to do so. -Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) +Please see the [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) for more information. From cbf983bb4a550ff26756ed7308fb03db42cffcff Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 17 Sep 2014 12:41:49 -0700 Subject: [PATCH 019/315] [SQL][DOCS] Improve table caching section Author: Michael Armbrust Closes #2434 from marmbrus/patch-1 and squashes the following commits: 67215be [Michael Armbrust] [SQL][DOCS] Improve table caching section --- docs/sql-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c498b41c43380..5212e19c41349 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -801,12 +801,12 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize -memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. +memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. -Note that if you call `cache` rather than `cacheTable`, tables will _not_ be cached using -the in-memory columnar format, and therefore `cacheTable` is strongly recommended for this use case. +Note that if you call `schemaRDD.cache()` rather than `sqlContext.cacheTable(...)`, tables will _not_ be cached using +the in-memory columnar format, and therefore `sqlContext.cacheTable(...)` is strongly recommended for this use case. Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running `SET key=value` commands using SQL. From 5044e4953a1744593d83fe90628fb4893e5463f1 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 17 Sep 2014 12:44:44 -0700 Subject: [PATCH 020/315] [SPARK-1455] [SPARK-3534] [Build] When possible, run SQL tests only. If the only files changed are related to SQL, then only run the SQL tests. This patch includes some cosmetic/maintainability refactoring. I would be more than happy to undo some of these changes if they are inappropriate. We can accept this patch mostly as-is and address the immediate need documented in [SPARK-3534](https://issues.apache.org/jira/browse/SPARK-3534), or we can keep it open until a satisfactory solution along the lines [discussed here](https://issues.apache.org/jira/browse/SPARK-1455?focusedCommentId=14136424&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14136424) is reached. Note: I had to hack this patch up to test it locally, so what I'm submitting here and what I tested are technically different. Author: Nicholas Chammas Closes #2420 from nchammas/selective-testing and squashes the following commits: db3fa2d [Nicholas Chammas] diff against master! f9e23f6 [Nicholas Chammas] when possible, run SQL tests only --- dev/run-tests | 156 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 106 insertions(+), 50 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 79401213a7fa2..53148d23f385f 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,44 +21,73 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then - if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" - elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" +# Remove work directory +rm -rf ./work + +# Build against the right verison of Hadoop. +{ + if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then + if [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop1.0" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=1.0.4" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" + fi fi -fi -if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" -fi + if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" + fi +} export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" -echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" - -# Remove work directory -rm -rf ./work - -if test -x "$JAVA_HOME/bin/java"; then - declare java_cmd="$JAVA_HOME/bin/java" -else - declare java_cmd=java -fi -JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') -[ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run because JDK version is < 1.8." +# Determine Java path and version. +{ + if test -x "$JAVA_HOME/bin/java"; then + declare java_cmd="$JAVA_HOME/bin/java" + else + declare java_cmd=java + fi + + # We can't use sed -r -e due to OS X / BSD compatibility; hence, all the parentheses. + JAVA_VERSION=$( + $java_cmd -version 2>&1 \ + | grep -e "^java version" --max-count=1 \ + | sed "s/java version \"\(.*\)\.\(.*\)\.\(.*\)\"/\1\2/" + ) + + if [ "$JAVA_VERSION" -lt 18 ]; then + echo "[warn] Java 8 tests will not run because JDK version is < 1.8." + fi +} -# Partial solution for SPARK-1455. Only run Hive tests if there are sql changes. +# Only run Hive tests if there are sql changes. +# Partial solution for SPARK-1455. if [ -n "$AMPLAB_JENKINS" ]; then git fetch origin master:master - diffs=`git diff --name-only master | grep "^\(sql/\)\|\(bin/spark-sql\)\|\(sbin/start-thriftserver.sh\)"` - if [ -n "$diffs" ]; then - echo "Detected changes in SQL. Will run Hive test suite." + + sql_diffs=$( + git diff --name-only master \ + | grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" + ) + + non_sql_diffs=$( + git diff --name-only master \ + | grep -v -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh" + ) + + if [ -n "$sql_diffs" ]; then + echo "[info] Detected changes in SQL. Will run Hive test suite." _RUN_SQL_TESTS=true + + if [ -z "$non_sql_diffs" ]; then + echo "[info] Detected no changes except in SQL. Will only run SQL tests." + _SQL_TESTS_ONLY=true + fi fi fi @@ -70,42 +99,69 @@ echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" -dev/check-license +./dev/check-license echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" -dev/lint-scala +./dev/lint-scala echo "" echo "=========================================================================" echo "Running Python style checks" echo "=========================================================================" -dev/lint-python +./dev/lint-python + +echo "" +echo "=========================================================================" +echo "Building Spark" +echo "=========================================================================" + +{ + # We always build with Hive because the PySpark Spark SQL tests need it. + BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" + + echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS" + + # NOTE: echo "q" is needed because sbt on encountering a build file with failure + #+ (either resolution or compilation) prompts the user for input either q, r, etc + #+ to quit or retry. This echo is there to make it not block. + # QUESTION: Why doesn't 'yes "q"' work? + # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + echo -e "q\n" \ + | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +} echo "" echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" -# Build Spark; we always build with Hive because the PySpark Spark SQL tests need it. -# echo "q" is needed because sbt on encountering a build file with failure -# (either resolution or compilation) prompts the user for input either q, r, -# etc to quit or retry. This echo is there to make it not block. -BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive " -echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" - -# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: -if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" -fi -# echo "q" is needed because sbt on encountering a build file with failure -# (either resolution or compilation) prompts the user for input either q, r, -# etc to quit or retry. This echo is there to make it not block. -echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS test | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +{ + # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. + if [ -n "$_RUN_SQL_TESTS" ]; then + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" + fi + + if [ -n "$_SQL_TESTS_ONLY" ]; then + SBT_MAVEN_TEST_ARGS="catalyst/test sql/test hive/test" + else + SBT_MAVEN_TEST_ARGS="test" + fi + + echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS $SBT_MAVEN_TEST_ARGS" + + # NOTE: echo "q" is needed because sbt on encountering a build file with failure + #+ (either resolution or compilation) prompts the user for input either q, r, etc + #+ to quit or retry. This echo is there to make it not block. + # QUESTION: Why doesn't 'yes "q"' work? + # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + echo -e "q\n" \ + | sbt/sbt "$SBT_MAVEN_PROFILES_ARGS" "$SBT_MAVEN_TEST_ARGS" \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +} echo "" echo "=========================================================================" @@ -117,4 +173,4 @@ echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" -dev/mima +./dev/mima From b3830b28f8a70224d87c89d8491c514c4c191d23 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 17 Sep 2014 15:07:57 -0700 Subject: [PATCH 021/315] Docs: move HA subsections to a deeper indentation level Makes the table of contents read better Author: Andrew Ash Closes #2402 from ash211/docs/better-indentation and squashes the following commits: ea0e130 [Andrew Ash] Move HA subsections to a deeper indentation level --- docs/spark-standalone.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index c791c81f8bfd0..99a8e43a6b489 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -307,7 +307,7 @@ tight firewall settings. For a complete list of ports to configure, see the By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below. -# Standby Masters with ZooKeeper +## Standby Masters with ZooKeeper **Overview** @@ -347,7 +347,7 @@ There's an important distinction to be made between "registering with a Master" Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of. -# Single-Node Recovery with Local File System +## Single-Node Recovery with Local File System **Overview** From 7fc3bb7c88a6bf5348d52ffee37a220a47c5a398 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 17 Sep 2014 15:14:04 -0700 Subject: [PATCH 022/315] [SPARK-3534] Fix expansion of testing arguments to sbt Testing arguments to `sbt` need to be passed as an array, not a single, long string. Fixes a bug introduced in #2420. Author: Nicholas Chammas Closes #2437 from nchammas/selective-testing and squashes the following commits: a9f9c1c [Nicholas Chammas] fix printing of sbt test arguments cf57cbf [Nicholas Chammas] fix sbt test arguments e33b978 [Nicholas Chammas] Merge pull request #2 from apache/master 0b47ca4 [Nicholas Chammas] Merge branch 'master' of github.com:nchammas/spark 8051486 [Nicholas Chammas] Merge pull request #1 from apache/master 03180a4 [Nicholas Chammas] Merge branch 'master' of github.com:nchammas/spark d4c5f43 [Nicholas Chammas] Merge pull request #6 from apache/master --- dev/run-tests | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 53148d23f385f..7c002160c3a4a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -141,17 +141,20 @@ echo "=========================================================================" { # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. + # This must be a single argument, as it is. if [ -n "$_RUN_SQL_TESTS" ]; then SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" fi if [ -n "$_SQL_TESTS_ONLY" ]; then - SBT_MAVEN_TEST_ARGS="catalyst/test sql/test hive/test" + # This must be an array of individual arguments. Otherwise, having one long string + #+ will be interpreted as a single test, which doesn't work. + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") else - SBT_MAVEN_TEST_ARGS="test" + SBT_MAVEN_TEST_ARGS=("test") fi - echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS $SBT_MAVEN_TEST_ARGS" + echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}" # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc @@ -159,7 +162,7 @@ echo "=========================================================================" # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? echo -e "q\n" \ - | sbt/sbt "$SBT_MAVEN_PROFILES_ARGS" "$SBT_MAVEN_TEST_ARGS" \ + | sbt/sbt "$SBT_MAVEN_PROFILES_ARGS" "${SBT_MAVEN_TEST_ARGS[@]}" \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } From cbc065039f5176acc49899462bfab2521da26701 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 17 Sep 2014 16:23:50 -0700 Subject: [PATCH 023/315] [SPARK-3571] Spark standalone cluster mode doesn't work. I think, this issue is caused by #1106 Author: Kousuke Saruta Closes #2436 from sarutak/SPARK-3571 and squashes the following commits: 7a4deea [Kousuke Saruta] Modified Master.scala to use numWorkersVisited and numWorkersAlive instead of stopPos 4e51e35 [Kousuke Saruta] Modified Master to prevent from 0 divide 4817ecd [Kousuke Saruta] Brushed up previous change 71e84b6 [Kousuke Saruta] Modified Master to enable schedule normally --- .../scala/org/apache/spark/deploy/master/Master.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2a3bd6ba0b9dc..432b552c58cd8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -489,23 +489,24 @@ private[spark] class Master( // First schedule drivers, they take strict precedence over applications // Randomization helps balance drivers val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val aliveWorkerNum = shuffledAliveWorkers.size + val numWorkersAlive = shuffledAliveWorkers.size var curPos = 0 + for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers // We assign workers to each waiting driver in a round-robin fashion. For each driver, we // start from the last worker that was assigned a driver, and continue onwards until we have // explored all alive workers. - curPos = (curPos + 1) % aliveWorkerNum - val startPos = curPos var launched = false - while (curPos != startPos && !launched) { + var numWorkersVisited = 0 + while (numWorkersVisited < numWorkersAlive && !launched) { val worker = shuffledAliveWorkers(curPos) + numWorkersVisited += 1 if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { launchDriver(worker, driver) waitingDrivers -= driver launched = true } - curPos = (curPos + 1) % aliveWorkerNum + curPos = (curPos + 1) % numWorkersAlive } } From 6688a266f2cb84c2d43b8e4d27f710718c4cc4a0 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 17 Sep 2014 16:31:58 -0700 Subject: [PATCH 024/315] [SPARK-3564][WebUI] Display App ID on HistoryPage Author: Kousuke Saruta Closes #2424 from sarutak/display-appid-on-webui and squashes the following commits: 417fe90 [Kousuke Saruta] Added "App ID column" to HistoryPage --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index c4ef8b63b0071..d25c29113d6da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -67,6 +67,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { } private val appHeader = Seq( + "App ID", "App Name", "Started", "Completed", @@ -81,7 +82,8 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { val duration = UIUtils.formatDuration(info.endTime - info.startTime) val lastUpdated = UIUtils.formatDate(info.lastUpdated) - + + From 1147973f1c7713013c7c0ca414482b511a730475 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 17 Sep 2014 16:52:27 -0700 Subject: [PATCH 025/315] [SPARK-3567] appId field in SparkDeploySchedulerBackend should be volatile Author: Kousuke Saruta Closes #2428 from sarutak/appid-volatile-modification and squashes the following commits: c7d890d [Kousuke Saruta] Added volatile modifier to appId field in SparkDeploySchedulerBackend --- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 2f45d192e1d4d..5c5ecc8434d78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -34,7 +34,7 @@ private[spark] class SparkDeploySchedulerBackend( var client: AppClient = null var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - var appId: String = _ + @volatile var appId: String = _ val registrationLock = new Object() var registrationDone = false From 3f169bfe3c322bf4344e13276dbbe34279b59ad0 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 17 Sep 2014 21:59:23 -0700 Subject: [PATCH 026/315] [SPARK-3565]Fix configuration item not consistent with document https://issues.apache.org/jira/browse/SPARK-3565 "spark.ports.maxRetries" should be "spark.port.maxRetries". Make the configuration keys in document and code consistent. Author: WangTaoTheTonic Closes #2427 from WangTaoTheTonic/fixPortRetries and squashes the following commits: c178813 [WangTaoTheTonic] Use blank lines trigger Jenkins 646f3fe [WangTaoTheTonic] also in SparkBuild.scala 3700dba [WangTaoTheTonic] Fix configuration item not consistent with document --- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 +++--- .../scala/org/apache/spark/deploy/JsonProtocolSuite.scala | 2 ++ docs/configuration.md | 2 +- project/SparkBuild.scala | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c76b7af18481d..ed063844323af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1382,15 +1382,15 @@ private[spark] object Utils extends Logging { } /** - * Default number of retries in binding to a port. + * Default maximum number of retries when binding to a port before giving up. */ val portMaxRetries: Int = { if (sys.props.contains("spark.testing")) { // Set a higher number of retries for tests... - sys.props.get("spark.ports.maxRetries").map(_.toInt).getOrElse(100) + sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100) } else { Option(SparkEnv.get) - .flatMap(_.conf.getOption("spark.ports.maxRetries")) + .flatMap(_.conf.getOption("spark.port.maxRetries")) .map(_.toInt) .getOrElse(16) } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 2a58c6a40d8e4..3f1cd0752e766 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -115,11 +115,13 @@ class JsonProtocolSuite extends FunSuite { workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis workerInfo } + def createExecutorRunner(): ExecutorRunner = { new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", new File("sparkHome"), new File("workDir"), "akka://worker", new SparkConf, ExecutorState.RUNNING) } + def createDriverRunner(): DriverRunner = { new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), null, "akka://worker") diff --git a/docs/configuration.md b/docs/configuration.md index 99faf51c6f3db..a6dd7245e1552 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -657,7 +657,7 @@ Apart from these, the following properties are also available, and may be useful diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ab9f8ba120e83..12ac82293df76 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -336,7 +336,7 @@ object TestSettings { fork := true, javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", - javaOptions in Test += "-Dspark.ports.maxRetries=100", + javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") From 5547fa1ee98bf166061804bd64df4cb51a656a3f Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 17 Sep 2014 22:37:11 -0700 Subject: [PATCH 027/315] [SPARK-3534] Add hive-thriftserver to SQL tests Addresses the problem pointed out in [this comment](https://github.com/apache/spark/pull/2441#issuecomment-55990116). Author: Nicholas Chammas Closes #2442 from nchammas/patch-1 and squashes the following commits: 7e68b60 [Nicholas Chammas] [SPARK-3534] Add hive-thriftserver to SQL tests --- dev/run-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 7c002160c3a4a..5f6df17b509a3 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -149,7 +149,7 @@ echo "=========================================================================" if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test") else SBT_MAVEN_TEST_ARGS=("test") fi From 6772afec2f57360bd886ba3c8487e6140869d8f0 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 17 Sep 2014 22:54:34 -0700 Subject: [PATCH 028/315] [Minor] rat exclude dependency-reduced-pom.xml Author: GuoQiang Li Closes #2326 from witgo/rat-excludes and squashes the following commits: 860904e [GuoQiang Li] rat exclude dependency-reduced-pom.xml --- .rat-excludes | 1 + 1 file changed, 1 insertion(+) diff --git a/.rat-excludes b/.rat-excludes index fb6323daf9211..1897ec8f747ca 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -58,3 +58,4 @@ dist/* .*iws logs .*scalastyle-output.xml +.*dependency-reduced-pom.xml From 3447d100900af15a7340a2f6a5430ffb6d9c6c23 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 18 Sep 2014 10:17:18 -0700 Subject: [PATCH 029/315] [SPARK-3547]Using a special exit code instead of 1 to represent ClassNotFoundExcepti... ...on As improvement of https://github.com/apache/spark/pull/1944, we should use more special exit code to represent ClassNotFoundException. Author: WangTaoTheTonic Closes #2421 from WangTaoTheTonic/classnotfoundExitCode and squashes the following commits: 645a22a [WangTaoTheTonic] Serveral typos to trigger Jenkins d6ae559 [WangTaoTheTonic] use 101 instead a2d6465 [WangTaoTheTonic] use 127 instead fbb232f [WangTaoTheTonic] Using a special exit code instead of 1 to represent ClassNotFoundException --- bin/spark-sql | 2 +- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../main/scala/org/apache/spark/network/nio/Connection.scala | 2 +- .../org/apache/spark/network/nio/ConnectionManager.scala | 4 ++-- sbin/start-thriftserver.sh | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bin/spark-sql b/bin/spark-sql index ae096530cad04..9d66140b6aa17 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -24,7 +24,7 @@ set -o posix CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" -CLASS_NOT_FOUND_EXIT_STATUS=1 +CLASS_NOT_FOUND_EXIT_STATUS=101 # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5ed3575816a38..5d15af1326ef0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -54,7 +54,7 @@ object SparkSubmit { private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" - private val CLASS_NOT_FOUND_EXIT_STATUS = 1 + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // Exposed for testing private[spark] var exitFn: () => Unit = () => System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 74074a8dcbfff..18172d359cb35 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -460,7 +460,7 @@ private[spark] class ReceivingConnection( if (currId != null) currId else super.getRemoteConnectionManagerId() } - // The reciever's remote address is the local socket on remote side : which is NOT + // The receiver's remote address is the local socket on remote side : which is NOT // the connection manager id of the receiver. // We infer that from the messages we receive on the receiver socket. private def processConnectionManagerId(header: MessageChunkHeader) { diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 09d3ea306515b..5aa7e94943561 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -501,7 +501,7 @@ private[nio] class ConnectionManager( def changeConnectionKeyInterest(connection: Connection, ops: Int) { keyInterestChangeRequests += ((connection.key, ops)) - // so that registerations happen ! + // so that registrations happen ! wakeupSelector() } @@ -832,7 +832,7 @@ private[nio] class ConnectionManager( } /** - * Send a message and block until an acknowldgment is received or an error occurs. + * Send a message and block until an acknowledgment is received or an error occurs. * @param connectionManagerId the message's destination * @param message the message being sent * @return a Future that either returns the acknowledgment message or captures an exception. diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 4ce40fe750384..ba953e763faab 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -27,7 +27,7 @@ set -o posix FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" -CLASS_NOT_FOUND_EXIT_STATUS=1 +CLASS_NOT_FOUND_EXIT_STATUS=101 function usage { echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" From 3ad4176cf980591469997a8a612bf422c90f86fd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 18 Sep 2014 10:30:17 -0700 Subject: [PATCH 030/315] SPARK-3579 Jekyll doc generation is different across environments. This patch makes some small changes to fix this problem: 1. We document specific versions of Jekyll/Kramdown to use that match those used when building the upstream docs. 2. We add a configuration for a property that for some reason varies across packages of Jekyll/Kramdown even with the same version. Author: Patrick Wendell Closes #2443 from pwendell/jekyll and squashes the following commits: 54ee2ab [Patrick Wendell] SPARK-3579 Jekyll doc generation is different across environments. --- docs/README.md | 16 ++++++++++------ docs/_config.yml | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index fdc89d2eb767a..79708c3df9106 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,12 +20,16 @@ In this directory you will find textfiles formatted using Markdown, with an ".md read those text files directly if you want. Start with index.md. The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). -To use the `jekyll` command, you will need to have Jekyll installed. -The easiest way to do this is via a Ruby Gem, see the -[jekyll installation instructions](http://jekyllrb.com/docs/installation). -If not already installed, you need to install `kramdown` and `jekyll-redirect-from` Gems -with `sudo gem install kramdown jekyll-redirect-from`. -Execute `jekyll build` from the `docs/` directory. Compiling the site with Jekyll will create a directory +`Jekyll` and a few dependencies must be installed for this to work. We recommend +installing via the Ruby Gem dependency manager. Since the exact HTML output +varies between versions of Jekyll and its dependencies, we list specific versions here +in some cases: + + $ sudo gem install jekyll -v 1.4.3 + $ sudo gem uninstall kramdown -v 1.4.1 + $ sudo gem install jekyll-redirect-from + +Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/_config.yml b/docs/_config.yml index d3ea2625c7448..7bc3a78e2d265 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -3,6 +3,11 @@ markdown: kramdown gems: - jekyll-redirect-from +# For some reason kramdown seems to behave differently on different +# OS/packages wrt encoding. So we hard code this config. +kramdown: + entity_output: numeric + # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. SPARK_VERSION: 1.0.0-SNAPSHOT From 6cab838b9803e3294c07bbf731c47154ec57afc0 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 18 Sep 2014 12:04:32 -0700 Subject: [PATCH 031/315] [SPARK-3566] [BUILD] .gitignore and .rat-excludes should consider Windows cmd file and Emacs' backup files Author: Kousuke Saruta Closes #2426 from sarutak/emacs-metafiles-ignore and squashes the following commits: a306020 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into emacs-metafiles-ignore 6a0a5eb [Kousuke Saruta] Added cmd file entry to .rat-excludes and .gitignore 897da63 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into emacs-metafiles-ignore 8cade06 [Kousuke Saruta] Modified .gitignore to ignore emacs lock file and backup file --- .gitignore | 3 +++ .rat-excludes | 1 + 2 files changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index a31bf7e0091f4..1bcd0165761ac 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ *~ +*.#* +*#*# *.swp *.ipr *.iml @@ -16,6 +18,7 @@ third_party/libmesos.so third_party/libmesos.dylib conf/java-opts conf/*.sh +conf/*.cmd conf/*.properties conf/*.conf conf/*.xml diff --git a/.rat-excludes b/.rat-excludes index 1897ec8f747ca..9fc99d7fca35d 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -20,6 +20,7 @@ log4j.properties.template metrics.properties.template slaves spark-env.sh +spark-env.cmd spark-env.sh.template log4j-defaults.properties bootstrap-tooltip.js From 471e6a3a47bd4b94878798f6f6fc93e2e672efff Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 18 Sep 2014 12:07:24 -0700 Subject: [PATCH 032/315] [SPARK-3589][Minor]remove redundant code https://issues.apache.org/jira/browse/SPARK-3589 "export CLASSPATH" in spark-class is redundant since same variable is exported before. We could reuse defined value "isYarnCluster" in SparkSubmit.scala. Author: WangTaoTheTonic Closes #2445 from WangTaoTheTonic/removeRedundant and squashes the following commits: 6fb6872 [WangTaoTheTonic] remove redundant code --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5d15af1326ef0..3dd1dd5b82fe8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -261,7 +261,7 @@ object SparkSubmit { } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class - if (clusterManager == YARN && deployMode == CLUSTER) { + if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.primaryResource != SPARK_INTERNAL) { childArgs += ("--jar", args.primaryResource) From b3ed37e5bad15d56db90c2b25fe11c1f758d3a97 Mon Sep 17 00:00:00 2001 From: Victsm Date: Thu, 18 Sep 2014 15:58:14 -0700 Subject: [PATCH 033/315] [SPARK-3560] Fixed setting spark.jars system property in yarn-cluster mode Author: Victsm Author: Min Shen Closes #2449 from Victsm/SPARK-3560 and squashes the following commits: 918405a [Victsm] Removed the additional space 4502a2a [Min Shen] [SPARK-3560] Fixed setting spark.jars system property in yarn-cluster mode. (cherry picked from commit 832dff64ddb1240a4c8e22fcdc0e993cc8c808de) Signed-off-by: Andrew Or --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3dd1dd5b82fe8..ec0324e24915a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -172,7 +172,7 @@ object SparkSubmit { // All cluster managers OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, @@ -205,6 +205,7 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options + OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 22b369a829418..0c324d8bdf6a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -154,6 +154,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { sysProps("spark.app.name") should be ("beauty") sysProps("spark.shuffle.spill") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") + sysProps.keys should not contain ("spark.jars") } test("handles YARN client mode") { From 9306297d1d888d0430f79b2133ee7377871a3a18 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 18 Sep 2014 17:49:28 -0700 Subject: [PATCH 034/315] [Minor Hot Fix] Move a line in SparkSubmit to the right place This was introduced in #2449 Author: Andrew Or Closes #2452 from andrewor14/standalone-hot-fix and squashes the following commits: d5190ca [Andrew Or] Put that line in the right place --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ec0324e24915a..d132ecb3f9989 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -183,6 +183,7 @@ object SparkSubmit { sysProp = "spark.driver.extraLibraryPath"), // Standalone cluster only + OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), @@ -205,7 +206,6 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, From e77fa81a61798c89d5a9b6c9dc067d11785254b7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Sep 2014 18:11:48 -0700 Subject: [PATCH 035/315] [SPARK-3554] [PySpark] use broadcast automatically for large closure Py4j can not handle large string efficiently, so we should use broadcast for large closure automatically. (Broadcast use local filesystem to pass through data). Author: Davies Liu Closes #2417 from davies/command and squashes the following commits: fbf4e97 [Davies Liu] bugfix aefd508 [Davies Liu] use broadcast automatically for large closure --- python/pyspark/rdd.py | 4 ++++ python/pyspark/sql.py | 8 ++++++-- python/pyspark/tests.py | 6 ++++++ python/pyspark/worker.py | 4 +++- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb09c191bed71..b43606b7304c5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2061,8 +2061,12 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) + # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) + if pickled_command > (1 << 20): # 1M + broadcast = self.ctx.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 8f6dbab240c7b..42a9920f10e6f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -27,7 +27,7 @@ from array import array from operator import itemgetter -from pyspark.rdd import RDD, PipelinedRDD +from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -975,7 +975,11 @@ def registerFunction(self, name, f, returnType=StringType()): command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) + if pickled_command > (1 << 20): # 1M + broadcast = self._sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self._sc._pickled_broadcast_vars], self._sc._gateway._gateway_client) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 0b3854347ad2e..7301966e48045 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -434,6 +434,12 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_large_closure(self): + N = 1000000 + data = [float(i) for i in xrange(N)] + m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum() + self.assertEquals(N, m) + def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) b = self.sc.parallelize(range(100, 105)) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 252176ac65fec..d6c06e2dbef62 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -77,10 +77,12 @@ def main(infile, outfile): _broadcastRegistry[bid] = Broadcast(bid, value) else: bid = - bid - 1 - _broadcastRegistry.remove(bid) + _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) + if isinstance(command, Broadcast): + command = pickleSer.loads(command.value) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) From e76ef5cb8eed6b78fb722b3d6fbeb9466a0e3499 Mon Sep 17 00:00:00 2001 From: Burak Date: Thu, 18 Sep 2014 22:18:51 -0700 Subject: [PATCH 036/315] [SPARK-3418] Sparse Matrix support (CCS) and additional native BLAS operations added Local `SparseMatrix` support added in Compressed Column Storage (CCS) format in addition to Level-2 and Level-3 BLAS operations such as dgemv and dgemm respectively. BLAS doesn't support sparse matrix operations, therefore support for `SparseMatrix`-`DenseMatrix` multiplication and `SparseMatrix`-`DenseVector` implementations have been added. I will post performance comparisons in the comments momentarily. Author: Burak Closes #2294 from brkyvz/SPARK-3418 and squashes the following commits: 88814ed [Burak] Hopefully fixed MiMa this time 47e49d5 [Burak] really fixed MiMa issue f0bae57 [Burak] [SPARK-3418] Fixed MiMa compatibility issues (excluded from check) 4b7dbec [Burak] 9/17 comments addressed 7af2f83 [Burak] sealed traits Vector and Matrix d3a8a16 [Burak] [SPARK-3418] Squashed missing alpha bug. 421045f [Burak] [SPARK-3418] New code review comments addressed f35a161 [Burak] [SPARK-3418] Code review comments addressed and multiplication further optimized 2508577 [Burak] [SPARK-3418] Fixed one more style issue d16e8a0 [Burak] [SPARK-3418] Fixed style issues and added documentation for methods 204a3f7 [Burak] [SPARK-3418] Fixed failing Matrix unit test 6025297 [Burak] [SPARK-3418] Fixed Scala-style errors dc7be71 [Burak] [SPARK-3418][MLlib] Matrix unit tests expanded with indexing and updating d2d5851 [Burak] [SPARK-3418][MLlib] Sparse Matrix support and additional native BLAS operations added --- .../org/apache/spark/mllib/linalg/BLAS.scala | 330 +++++++++++++++++- .../apache/spark/mllib/linalg/Matrices.scala | 232 +++++++++++- .../apache/spark/mllib/linalg/Vectors.scala | 2 +- .../apache/spark/mllib/linalg/BLASSuite.scala | 111 ++++++ .../linalg/BreezeMatrixConversionSuite.scala | 24 +- .../spark/mllib/linalg/MatricesSuite.scala | 76 ++++ .../spark/mllib/util/TestingUtils.scala | 65 +++- project/MimaExcludes.scala | 4 +- 8 files changed, 834 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 70e23033c8754..54ee930d61003 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -18,13 +18,17 @@ package org.apache.spark.mllib.linalg import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices. */ -private[mllib] object BLAS extends Serializable { +private[mllib] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ + @transient private var _nativeBLAS: NetlibBLAS = _ // For level-1 routines, we use Java implementation. private def f2jBLAS: NetlibBLAS = { @@ -197,4 +201,328 @@ private[mllib] object BLAS extends Serializable { throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } } + + // For level-3 routines, we use the native BLAS. + private def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = NativeBLAS + } + _nativeBLAS + } + + /** + * C := alpha * A * B + beta * C + * @param transA whether to use the transpose of matrix A (true), or A itself (false). + * @param transB whether to use the transpose of matrix B (true), or B itself (false). + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + if (alpha == 0.0) { + logDebug("gemm: alpha is equal to 0. Returning C.") + } else { + A match { + case sparse: SparseMatrix => + gemm(transA, transB, alpha, sparse, B, beta, C) + case dense: DenseMatrix => + gemm(transA, transB, alpha, dense, B, beta, C) + case _ => + throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") + } + } + } + + /** + * C := alpha * A * B + beta * C + * + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + gemm(false, false, alpha, A, B, beta, C) + } + + /** + * C := alpha * A * B + beta * C + * For `DenseMatrix` A. + */ + private def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: DenseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + val mA: Int = if (!transA) A.numRows else A.numCols + val nB: Int = if (!transB) B.numCols else B.numRows + val kA: Int = if (!transA) A.numCols else A.numRows + val kB: Int = if (!transB) B.numRows else B.numCols + val tAstr = if (!transA) "N" else "T" + val tBstr = if (!transB) "N" else "T" + + require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") + require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") + require(nB == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") + + nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, + beta, C.values, C.numRows) + } + + /** + * C := alpha * A * B + beta * C + * For `SparseMatrix` A. + */ + private def gemm( + transA: Boolean, + transB: Boolean, + alpha: Double, + A: SparseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix): Unit = { + val mA: Int = if (!transA) A.numRows else A.numCols + val nB: Int = if (!transB) B.numCols else B.numRows + val kA: Int = if (!transA) A.numCols else A.numRows + val kB: Int = if (!transB) B.numRows else B.numCols + + require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") + require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") + require(nB == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") + + val Avals = A.values + val Arows = if (!transA) A.rowIndices else A.colPtrs + val Acols = if (!transA) A.colPtrs else A.rowIndices + + // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices + if (transA){ + var colCounterForB = 0 + if (!transB) { // Expensive to put the check inside the loop + while (colCounterForB < nB) { + var rowCounterForA = 0 + val Cstart = colCounterForB * mA + val Bstart = colCounterForB * kA + while (rowCounterForA < mA) { + var i = Arows(rowCounterForA) + val indEnd = Arows(rowCounterForA + 1) + var sum = 0.0 + while (i < indEnd) { + sum += Avals(i) * B.values(Bstart + Acols(i)) + i += 1 + } + val Cindex = Cstart + rowCounterForA + C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + rowCounterForA += 1 + } + colCounterForB += 1 + } + } else { + while (colCounterForB < nB) { + var rowCounter = 0 + val Cstart = colCounterForB * mA + while (rowCounter < mA) { + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + while (i < indEnd) { + sum += Avals(i) * B(colCounterForB, Acols(i)) + i += 1 + } + val Cindex = Cstart + rowCounter + C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + rowCounter += 1 + } + colCounterForB += 1 + } + } + } else { + // Scale matrix first if `beta` is not equal to 0.0 + if (beta != 0.0){ + f2jBLAS.dscal(C.values.length, beta, C.values, 1) + } + // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of + // B, and added to C. + var colCounterForB = 0 // the column to be updated in C + if (!transB) { // Expensive to put the check inside the loop + while (colCounterForB < nB) { + var colCounterForA = 0 // The column of A to multiply with the row of B + val Bstart = colCounterForB * kB + val Cstart = colCounterForB * mA + while (colCounterForA < kA) { + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val Bval = B.values(Bstart + colCounterForA) * alpha + while (i < indEnd){ + C.values(Cstart + Arows(i)) += Avals(i) * Bval + i += 1 + } + colCounterForA += 1 + } + colCounterForB += 1 + } + } else { + while (colCounterForB < nB) { + var colCounterForA = 0 // The column of A to multiply with the row of B + val Cstart = colCounterForB * mA + while (colCounterForA < kA){ + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val Bval = B(colCounterForB, colCounterForA) * alpha + while (i < indEnd){ + C.values(Cstart + Arows(i)) += Avals(i) * Bval + i += 1 + } + colCounterForA += 1 + } + colCounterForB += 1 + } + } + } + } + + /** + * y := alpha * A * x + beta * y + * @param trans whether to use the transpose of matrix A (true), or A itself (false). + * @param alpha a scalar to scale the multiplication A * x. + * @param A the matrix A that will be left multiplied to x. Size of m x n. + * @param x the vector x that will be left multiplied by A. Size of n x 1. + * @param beta a scalar that can be used to scale vector y. + * @param y the resulting vector y. Size of m x 1. + */ + def gemv( + trans: Boolean, + alpha: Double, + A: Matrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + + val mA: Int = if (!trans) A.numRows else A.numCols + val nx: Int = x.size + val nA: Int = if (!trans) A.numCols else A.numRows + + require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") + require(mA == y.size, + s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + if (alpha == 0.0) { + logDebug("gemv: alpha is equal to 0. Returning y.") + } else { + A match { + case sparse: SparseMatrix => + gemv(trans, alpha, sparse, x, beta, y) + case dense: DenseMatrix => + gemv(trans, alpha, dense, x, beta, y) + case _ => + throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") + } + } + } + + /** + * y := alpha * A * x + beta * y + * + * @param alpha a scalar to scale the multiplication A * x. + * @param A the matrix A that will be left multiplied to x. Size of m x n. + * @param x the vector x that will be left multiplied by A. Size of n x 1. + * @param beta a scalar that can be used to scale vector y. + * @param y the resulting vector y. Size of m x 1. + */ + def gemv( + alpha: Double, + A: Matrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + gemv(false, alpha, A, x, beta, y) + } + + /** + * y := alpha * A * x + beta * y + * For `DenseMatrix` A. + */ + private def gemv( + trans: Boolean, + alpha: Double, + A: DenseMatrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + val tStrA = if (!trans) "N" else "T" + nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + y.values, 1) + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A. + */ + private def gemv( + trans: Boolean, + alpha: Double, + A: SparseMatrix, + x: DenseVector, + beta: Double, + y: DenseVector): Unit = { + + val mA: Int = if(!trans) A.numRows else A.numCols + val nA: Int = if(!trans) A.numCols else A.numRows + + val Avals = A.values + val Arows = if (!trans) A.rowIndices else A.colPtrs + val Acols = if (!trans) A.colPtrs else A.rowIndices + + // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices + if (trans){ + var rowCounter = 0 + while (rowCounter < mA){ + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + while(i < indEnd){ + sum += Avals(i) * x.values(Acols(i)) + i += 1 + } + y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha + rowCounter += 1 + } + } else { + // Scale vector first if `beta` is not equal to 0.0 + if (beta != 0.0){ + scal(beta, y) + } + // Perform matrix-vector multiplication and add to y + var colCounterForA = 0 + while (colCounterForA < nA){ + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + val xVal = x.values(colCounterForA) * alpha + while (i < indEnd){ + val rowIndex = Arows(i) + y.values(rowIndex) += Avals(i) * xVal + i += 1 + } + colCounterForA += 1 + } + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index b11ba5d30fbd3..5711532abcf80 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,16 @@ package org.apache.spark.mllib.linalg -import breeze.linalg.{Matrix => BM, DenseMatrix => BDM} +import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} + +import org.apache.spark.util.random.XORShiftRandom + +import java.util.Arrays /** * Trait for a local matrix. */ -trait Matrix extends Serializable { +sealed trait Matrix extends Serializable { /** Number of rows. */ def numRows: Int @@ -37,8 +41,46 @@ trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ - private[mllib] def apply(i: Int, j: Int): Double = toBreeze(i, j) + private[mllib] def apply(i: Int, j: Int): Double + + /** Return the index for the (i, j)-th element in the backing array. */ + private[mllib] def index(i: Int, j: Int): Int + + /** Update element at (i, j) */ + private[mllib] def update(i: Int, j: Int, v: Double): Unit + + /** Get a deep copy of the matrix. */ + def copy: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + def multiply(y: DenseMatrix): DenseMatrix = { + val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] + BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + C + } + + /** Convenience method for `Matrix`-`DenseVector` multiplication. */ + def multiply(y: DenseVector): DenseVector = { + val output = new DenseVector(new Array[Double](numRows)) + BLAS.gemv(1.0, this, y, 0.0, output) + output + } + + /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ + def transposeMultiply(y: DenseMatrix): DenseMatrix = { + val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] + BLAS.gemm(true, false, 1.0, this, y, 0.0, C) + C + } + + /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ + def transposeMultiply(y: DenseVector): DenseVector = { + val output = new DenseVector(new Array[Double](numCols)) + BLAS.gemv(true, 1.0, this, y, 0.0, output) + output + } + + /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() } @@ -59,11 +101,98 @@ trait Matrix extends Serializable { */ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { - require(values.length == numRows * numCols) + require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + + private[mllib] def apply(i: Int): Double = values(i) + + private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) + + private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + + private[mllib] def update(i: Int, j: Int, v: Double): Unit = { + values(index(i, j)) = v + } + + override def copy = new DenseMatrix(numRows, numCols, values.clone()) +} + +/** + * Column-majored sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing order for each + * column + * @param values non-zero matrix entries in column major + */ +class SparseMatrix( + val numRows: Int, + val numCols: Int, + val colPtrs: Array[Int], + val rowIndices: Array[Int], + val values: Array[Double]) extends Matrix { + + require(values.length == rowIndices.length, "The number of row indices and values don't match! " + + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") + require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + + s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + + s"numCols: $numCols") + + override def toArray: Array[Double] = { + val arr = new Array[Double](numRows * numCols) + var j = 0 + while (j < numCols) { + var i = colPtrs(j) + val indEnd = colPtrs(j + 1) + val offset = j * numRows + while (i < indEnd) { + val rowIndex = rowIndices(i) + arr(offset + rowIndex) = values(i) + i += 1 + } + j += 1 + } + arr + } + + private[mllib] def toBreeze: BM[Double] = + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + + private[mllib] def apply(i: Int, j: Int): Double = { + val ind = index(i, j) + if (ind < 0) 0.0 else values(ind) + } + + private[mllib] def index(i: Int, j: Int): Int = { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } + + private[mllib] def update(i: Int, j: Int, v: Double): Unit = { + val ind = index(i, j) + if (ind == -1){ + throw new NoSuchElementException("The given row and column indices correspond to a zero " + + "value. Only non-zero elements in Sparse Matrices can be updated.") + } else { + values(index(i, j)) = v + } + } + + override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } /** @@ -82,6 +211,24 @@ object Matrices { new DenseMatrix(numRows, numCols, values) } + /** + * Creates a column-majored sparse matrix in Compressed Sparse Column (CSC) format. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry + * @param values non-zero matrix entries in column major + */ + def sparse( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]): Matrix = { + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + } + /** * Creates a Matrix instance from a breeze matrix. * @param breeze a breeze matrix @@ -93,9 +240,84 @@ object Matrices { require(dm.majorStride == dm.rows, "Do not support stride size different from the number of rows.") new DenseMatrix(dm.rows, dm.cols, dm.data) + case sm: BSM[Double] => + new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( s"Do not support conversion from type ${breeze.getClass.getName}.") } } + + /** + * Generate a `DenseMatrix` consisting of zeros. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros + */ + def zeros(numRows: Int, numCols: Int): Matrix = + new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + + /** + * Generate a `DenseMatrix` consisting of ones. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones + */ + def ones(numRows: Int, numCols: Int): Matrix = + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + + /** + * Generate an Identity Matrix in `DenseMatrix` format. + * @param n number of rows and columns of the matrix + * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal + */ + def eye(n: Int): Matrix = { + val identity = Matrices.zeros(n, n) + var i = 0 + while (i < n){ + identity.update(i, i, 1.0) + i += 1 + } + identity + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) + */ + def rand(numRows: Int, numCols: Int): Matrix = { + val rand = new XORShiftRandom + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) + } + + /** + * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. + * @param numRows number of rows of the matrix + * @param numCols number of columns of the matrix + * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) + */ + def randn(numRows: Int, numCols: Int): Matrix = { + val rand = new XORShiftRandom + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) + } + + /** + * Generate a diagonal matrix in `DenseMatrix` format from the supplied values. + * @param vector a `Vector` tat will form the values on the diagonal of the matrix + * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` + * on the diagonal + */ + def diag(vector: Vector): Matrix = { + val n = vector.size + val matrix = Matrices.eye(n) + val values = vector.toArray + var i = 0 + while (i < n) { + matrix.update(i, i, values(i)) + i += 1 + } + matrix + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index a45781d12e41e..6af225b7f49f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -33,7 +33,7 @@ import org.apache.spark.SparkException * * Note: Users should not implement this interface. */ -trait Vector extends Serializable { +sealed trait Vector extends Serializable { /** * Size of the vector. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 1952e6734ecf7..5d70c914f14b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,4 +126,115 @@ class BLASSuite extends FunSuite { } } } + + test("gemm") { + + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) + val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + + assert(dA multiply B ~== expected absTol 1e-15) + assert(sA multiply B ~== expected absTol 1e-15) + + val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) + val C2 = C1.copy + val C3 = C1.copy + val C4 = C1.copy + val C5 = C1.copy + val C6 = C1.copy + val C7 = C1.copy + val C8 = C1.copy + val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) + val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) + + gemm(1.0, dA, B, 2.0, C1) + gemm(1.0, sA, B, 2.0, C2) + gemm(2.0, dA, B, 2.0, C3) + gemm(2.0, sA, B, 2.0, C4) + assert(C1 ~== expected2 absTol 1e-15) + assert(C2 ~== expected2 absTol 1e-15) + assert(C3 ~== expected3 absTol 1e-15) + assert(C4 ~== expected3 absTol 1e-15) + + withClue("columns of A don't match the rows of B") { + intercept[Exception] { + gemm(true, false, 1.0, dA, B, 2.0, C1) + } + } + + val dAT = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sAT = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT transposeMultiply B ~== expected absTol 1e-15) + assert(sAT transposeMultiply B ~== expected absTol 1e-15) + + gemm(true, false, 1.0, dAT, B, 2.0, C5) + gemm(true, false, 1.0, sAT, B, 2.0, C6) + gemm(true, false, 2.0, dAT, B, 2.0, C7) + gemm(true, false, 2.0, sAT, B, 2.0, C8) + assert(C5 ~== expected2 absTol 1e-15) + assert(C6 ~== expected2 absTol 1e-15) + assert(C7 ~== expected3 absTol 1e-15) + assert(C8 ~== expected3 absTol 1e-15) + } + + test("gemv") { + + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val x = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) + + assert(dA multiply x ~== expected absTol 1e-15) + assert(sA multiply x ~== expected absTol 1e-15) + + val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) + val y2 = y1.copy + val y3 = y1.copy + val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + val y8 = y1.copy + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) + val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) + + gemv(1.0, dA, x, 2.0, y1) + gemv(1.0, sA, x, 2.0, y2) + gemv(2.0, dA, x, 2.0, y3) + gemv(2.0, sA, x, 2.0, y4) + assert(y1 ~== expected2 absTol 1e-15) + assert(y2 ~== expected2 absTol 1e-15) + assert(y3 ~== expected3 absTol 1e-15) + assert(y4 ~== expected3 absTol 1e-15) + withClue("columns of A don't match the rows of B") { + intercept[Exception] { + gemv(true, 1.0, dA, x, 2.0, y1) + } + } + + val dAT = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sAT = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT transposeMultiply x ~== expected absTol 1e-15) + assert(sAT transposeMultiply x ~== expected absTol 1e-15) + + gemv(true, 1.0, dAT, x, 2.0, y5) + gemv(true, 1.0, sAT, x, 2.0, y6) + gemv(true, 2.0, dAT, x, 2.0, y7) + gemv(true, 2.0, sAT, x, 2.0, y8) + assert(y5 ~== expected2 absTol 1e-15) + assert(y6 ~== expected2 absTol 1e-15) + assert(y7 ~== expected3 absTol 1e-15) + assert(y8 ~== expected3 absTol 1e-15) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 82d49c76ed02b..73a6d3a27d868 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import org.scalatest.FunSuite -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} class BreezeMatrixConversionSuite extends FunSuite { test("dense matrix to breeze") { @@ -37,4 +37,26 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") } + + test("sparse matrix to breeze") { + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) + val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] + assert(breeze.rows === mat.numRows) + assert(breeze.cols === mat.numCols) + assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") + } + + test("sparse breeze matrix to sparse matrix") { + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val breeze = new BSM[Double](values, 3, 2, colPtrs, rowIndices) + val mat = Matrices.fromBreeze(breeze).asInstanceOf[SparseMatrix] + assert(mat.numRows === breeze.rows) + assert(mat.numCols === breeze.cols) + assert(mat.values.eq(breeze.data), "should not copy data") + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 9c66b4db9f16b..5f8b8c4b72697 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -36,4 +36,80 @@ class MatricesSuite extends FunSuite { Matrices.dense(3, 2, Array(0.0, 1.0, 2.0)) } } + + test("sparse matrix construction") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val mat = Matrices.sparse(m, n, colPtrs, rowIndices, values).asInstanceOf[SparseMatrix] + assert(mat.numRows === m) + assert(mat.numCols === n) + assert(mat.values.eq(values), "should not copy data") + assert(mat.colPtrs.eq(colPtrs), "should not copy data") + assert(mat.rowIndices.eq(rowIndices), "should not copy data") + } + + test("sparse matrix construction with wrong number of elements") { + intercept[IllegalArgumentException] { + Matrices.sparse(3, 2, Array(0, 1), Array(1, 2, 1), Array(0.0, 1.0, 2.0)) + } + + intercept[IllegalArgumentException] { + Matrices.sparse(3, 2, Array(0, 1, 2), Array(1, 2), Array(0.0, 1.0, 2.0)) + } + } + + test("matrix copies are deep copies") { + val m = 3 + val n = 2 + + val denseMat = Matrices.dense(m, n, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + val denseCopy = denseMat.copy + + assert(!denseMat.toArray.eq(denseCopy.toArray)) + + val values = Array(1.0, 2.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 1, 2) + val sparseMat = Matrices.sparse(m, n, colPtrs, rowIndices, values) + val sparseCopy = sparseMat.copy + + assert(!sparseMat.toArray.eq(sparseCopy.toArray)) + } + + test("matrix indexing and updating") { + val m = 3 + val n = 2 + val allValues = Array(0.0, 1.0, 2.0, 3.0, 4.0, 0.0) + + val denseMat = new DenseMatrix(m, n, allValues) + + assert(denseMat(0, 1) === 3.0) + assert(denseMat(0, 1) === denseMat.values(3)) + assert(denseMat(0, 1) === denseMat(3)) + assert(denseMat(0, 0) === 0.0) + + denseMat.update(0, 0, 10.0) + assert(denseMat(0, 0) === 10.0) + assert(denseMat.values(0) === 10.0) + + val sparseValues = Array(1.0, 2.0, 3.0, 4.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(1, 2, 0, 1) + val sparseMat = new SparseMatrix(m, n, colPtrs, rowIndices, sparseValues) + + assert(sparseMat(0, 1) === 3.0) + assert(sparseMat(0, 1) === sparseMat.values(2)) + assert(sparseMat(0, 0) === 0.0) + + intercept[NoSuchElementException] { + sparseMat.update(0, 0, 10.0) + } + + sparseMat.update(0, 1, 10.0) + assert(sparseMat(0, 1) === 10.0) + assert(sparseMat.values(2) === 10.0) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 29cc42d8cbea7..30b906aaa3ba4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.scalatest.exceptions.TestFailedException object TestingUtils { @@ -169,4 +169,67 @@ object TestingUtils { override def toString = x.toString } + case class CompareMatrixRightSide( + fun: (Matrix, Matrix, Double) => Boolean, y: Matrix, eps: Double, method: String) + + /** + * Implicit class for comparing two matrices using relative tolerance or absolute tolerance. + */ + implicit class MatrixWithAlmostEquals(val x: Matrix) { + + /** + * When the difference of two vectors are within eps, returns true; otherwise, returns false. + */ + def ~=(r: CompareMatrixRightSide): Boolean = r.fun(x, r.y, r.eps) + + /** + * When the difference of two vectors are within eps, returns false; otherwise, returns true. + */ + def !~=(r: CompareMatrixRightSide): Boolean = !r.fun(x, r.y, r.eps) + + /** + * Throws exception when the difference of two vectors are NOT within eps; + * otherwise, returns true. + */ + def ~==(r: CompareMatrixRightSide): Boolean = { + if (!r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Expected \n$x\n and \n${r.y}\n to be within ${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Throws exception when the difference of two matrices are within eps; otherwise, returns true. + */ + def !~==(r: CompareMatrixRightSide): Boolean = { + if (r.fun(x, r.y, r.eps)) { + throw new TestFailedException( + s"Did not expect \n$x\n and \n${r.y}\n to be within " + + "${r.eps}${r.method} for all elements.", 0) + } + true + } + + /** + * Comparison using absolute tolerance. + */ + def absTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( + (x: Matrix, y: Matrix, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps) + }, x, eps, ABS_TOL_MSG) + + /** + * Comparison using relative tolerance. Note that comparing against sparse vector + * with elements having value of zero will raise exception because it involves with + * comparing against zero. + */ + def relTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide( + (x: Matrix, y: Matrix, eps: Double) => { + x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps) + }, x, eps, REL_TOL_MSG) + + override def toString = x.toString + } + } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2f1e05dfcc7b1..3280e662fa0b1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,7 +37,9 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("graphx") - ) + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ + MimaBuild.excludeSparkClass("mllib.linalg.Vector") case v if v.startsWith("1.1") => Seq( From 3bbbdd8180cf316c6f8dde0e879410b6b29f8cc3 Mon Sep 17 00:00:00 2001 From: Larry Xiao Date: Thu, 18 Sep 2014 23:32:32 -0700 Subject: [PATCH 037/315] [SPARK-2062][GraphX] VertexRDD.apply does not use the mergeFunc VertexRDD.apply had a bug where it ignored the merge function for duplicate vertices and instead used whichever vertex attribute occurred first. This commit fixes the bug by passing the merge function through to ShippableVertexPartition.apply, which merges any duplicates using the merge function and then fills in missing vertices using the specified default vertex attribute. This commit also adds a unit test for VertexRDD.apply. Author: Larry Xiao Author: Blie Arkansol Author: Ankur Dave Closes #1903 from larryxiao/2062 and squashes the following commits: 625aa9d [Blie Arkansol] Merge pull request #1 from ankurdave/SPARK-2062 476770b [Ankur Dave] ShippableVertexPartition.initFrom: Don't run mergeFunc on default values 614059f [Larry Xiao] doc update: note about the default null value vertices construction dfdb3c9 [Larry Xiao] minor fix 1c70366 [Larry Xiao] scalastyle check: wrap line, parameter list indent 4 spaces e4ca697 [Larry Xiao] [TEST] VertexRDD.apply mergeFunc 6a35ea8 [Larry Xiao] [TEST] VertexRDD.apply mergeFunc 4fbc29c [Blie Arkansol] undo unnecessary change efae765 [Larry Xiao] fix mistakes: should be able to call with or without mergeFunc b2422f9 [Larry Xiao] Merge branch '2062' of github.com:larryxiao/spark into 2062 52dc7f7 [Larry Xiao] pass mergeFunc to VertexPartitionBase, where merge is handled 581e9ee [Larry Xiao] TODO: VertexRDDSuite 20d80a3 [Larry Xiao] [SPARK-2062][GraphX] VertexRDD.apply does not use the mergeFunc --- .../org/apache/spark/graphx/VertexRDD.scala | 4 +-- .../impl/ShippableVertexPartition.scala | 28 +++++++++++++++---- .../apache/spark/graphx/VertexRDDSuite.scala | 11 ++++++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 04fbc9dbab8d1..2c8b245955d12 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -392,7 +392,7 @@ object VertexRDD { */ def apply[VD: ClassTag]( vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { - VertexRDD(vertices, edges, defaultVal, (a, b) => b) + VertexRDD(vertices, edges, defaultVal, (a, b) => a) } /** @@ -419,7 +419,7 @@ object VertexRDD { (vertexIter, routingTableIter) => val routingTable = if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty - Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal)) + Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc)) } new VertexRDD(vertexPartitions) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index dca54b8a7da86..5412d720475dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -36,7 +36,7 @@ private[graphx] object ShippableVertexPartition { /** Construct a `ShippableVertexPartition` from the given vertices without any routing table. */ def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]): ShippableVertexPartition[VD] = - apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD]) + apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD], (a, b) => a) /** * Construct a `ShippableVertexPartition` from the given vertices with the specified routing @@ -44,10 +44,28 @@ object ShippableVertexPartition { */ def apply[VD: ClassTag]( iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD) - : ShippableVertexPartition[VD] = { - val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal)) - val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, (a: VD, b: VD) => a) - new ShippableVertexPartition(index, values, mask, routingTable) + : ShippableVertexPartition[VD] = + apply(iter, routingTable, defaultVal, (a, b) => a) + + /** + * Construct a `ShippableVertexPartition` from the given vertices with the specified routing + * table, filling in missing vertices mentioned in the routing table using `defaultVal`, + * and merging duplicate vertex atrribute with mergeFunc. + */ + def apply[VD: ClassTag]( + iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD, + mergeFunc: (VD, VD) => VD): ShippableVertexPartition[VD] = { + val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] + // Merge the given vertices using mergeFunc + iter.foreach { pair => + map.setMerge(pair._1, pair._2, mergeFunc) + } + // Fill in missing vertices mentioned in the routing table + routingTable.iterator.foreach { vid => + map.changeValue(vid, defaultVal, identity) + } + + new ShippableVertexPartition(map.keySet, map._values, map.keySet.getBitSet, routingTable) } import scala.language.implicitConversions diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index cc86bafd2d644..42d3f21dbae98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -99,4 +99,15 @@ class VertexRDDSuite extends FunSuite with LocalSparkContext { } } + test("mergeFunc") { + // test to see if the mergeFunc is working correctly + withSpark { sc => + val verts = sc.parallelize(List((0L, 0), (1L, 1), (1L, 2), (2L, 3), (2L, 3), (2L, 3))) + val edges = EdgeRDD.fromEdges(sc.parallelize(List.empty[Edge[Int]])) + val rdd = VertexRDD(verts, edges, 0, (a: Int, b: Int) => a + b) + // test merge function + assert(rdd.collect.toSet == Set((0L, 0), (1L, 3), (2L, 9))) + } + } + } From a48956f5825d2255736eee50de79fba79bcb7e39 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 19 Sep 2014 10:49:42 -0700 Subject: [PATCH 038/315] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #726 (close requested by 'pwendell') Closes #151 (close requested by 'pwendell') From be0c7563ea001a59469dbba219d2a8ef5785afa3 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 19 Sep 2014 14:31:50 -0700 Subject: [PATCH 039/315] [SPARK-1701] Clarify slice vs partition in the programming guide This is a partial solution to SPARK-1701, only addressing the documentation confusion. Additional work can be to actually change the numSlices parameter name across languages, with care required for scala & python to maintain backward compatibility for named parameters. Author: Matthew Farrellee Closes #2305 from mattf/SPARK-1701 and squashes the following commits: c0af05d [Matthew Farrellee] Further tweak 06f80fc [Matthew Farrellee] Wording tweak from Josh Rosen's review 7b045e0 [Matthew Farrellee] [SPARK-1701] Clarify slice vs partition in the programming guide --- docs/programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 624cc744dfd51..01d378af574b5 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -286,7 +286,7 @@ We describe operations on distributed datasets later on. -One important parameter for parallel collections is the number of *slices* to cut the dataset into. Spark will run one task for each slice of the cluster. Typically you want 2-4 slices for each CPU in your cluster. Normally, Spark tries to set the number of slices automatically based on your cluster. However, you can also set it manually by passing it as a second parameter to `parallelize` (e.g. `sc.parallelize(data, 10)`). +One important parameter for parallel collections is the number of *partitions* to cut the dataset into. Spark will run one task for each partition of the cluster. Typically you want 2-4 partitions for each CPU in your cluster. Normally, Spark tries to set the number of partitions automatically based on your cluster. However, you can also set it manually by passing it as a second parameter to `parallelize` (e.g. `sc.parallelize(data, 10)`). Note: some places in the code use the term slices (a synonym for partitions) to maintain backward compatibility. ## External Datasets @@ -311,7 +311,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Scala API also supports several other data formats: @@ -343,7 +343,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Java API also supports several other data formats: @@ -375,7 +375,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Python API also supports several other data formats: From a03e5b81e91d9d792b6a2e01d1505394ea303dd8 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 19 Sep 2014 14:35:22 -0700 Subject: [PATCH 040/315] [SPARK-1701] [PySpark] remove slice terminology from python examples Author: Matthew Farrellee Closes #2304 from mattf/SPARK-1701-partition-over-slice-for-python-examples and squashes the following commits: 928a581 [Matthew Farrellee] [SPARK-1701] [PySpark] remove slice terminology from python examples --- examples/src/main/python/als.py | 12 ++++++------ examples/src/main/python/pi.py | 8 ++++---- examples/src/main/python/transitive_closure.py | 6 +++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 5b1fa4d997eeb..70b6146e39a87 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -54,7 +54,7 @@ def update(i, vec, mat, ratings): if __name__ == "__main__": """ - Usage: als [M] [U] [F] [iterations] [slices]" + Usage: als [M] [U] [F] [iterations] [partitions]" """ print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an @@ -66,10 +66,10 @@ def update(i, vec, mat, ratings): U = int(sys.argv[2]) if len(sys.argv) > 2 else 500 F = int(sys.argv[3]) if len(sys.argv) > 3 else 10 ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5 - slices = int(sys.argv[5]) if len(sys.argv) > 5 else 2 + partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2 - print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \ - (M, U, F, ITERATIONS, slices) + print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \ + (M, U, F, ITERATIONS, partitions) R = matrix(rand(M, F)) * matrix(rand(U, F).T) ms = matrix(rand(M, F)) @@ -80,7 +80,7 @@ def update(i, vec, mat, ratings): usb = sc.broadcast(us) for i in range(ITERATIONS): - ms = sc.parallelize(range(M), slices) \ + ms = sc.parallelize(range(M), partitions) \ .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \ .collect() # collect() returns a list, so array ends up being @@ -88,7 +88,7 @@ def update(i, vec, mat, ratings): ms = matrix(np.array(ms)[:, :, 0]) msb = sc.broadcast(ms) - us = sc.parallelize(range(U), slices) \ + us = sc.parallelize(range(U), partitions) \ .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \ .collect() us = matrix(np.array(us)[:, :, 0]) diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index ee9036adfa281..a7c74e969cdb9 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -24,18 +24,18 @@ if __name__ == "__main__": """ - Usage: pi [slices] + Usage: pi [partitions] """ sc = SparkContext(appName="PythonPi") - slices = int(sys.argv[1]) if len(sys.argv) > 1 else 2 - n = 100000 * slices + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 + n = 100000 * partitions def f(_): x = random() * 2 - 1 y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n + 1), slices).map(f).reduce(add) + count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) sc.stop() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index bf331b542c438..00a281bfb6506 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -37,11 +37,11 @@ def generateGraph(): if __name__ == "__main__": """ - Usage: transitive_closure [slices] + Usage: transitive_closure [partitions] """ sc = SparkContext(appName="PythonTransitiveClosure") - slices = int(sys.argv[1]) if len(sys.argv) > 1 else 2 - tc = sc.parallelize(generateGraph(), slices).cache() + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 + tc = sc.parallelize(generateGraph(), partitions).cache() # Linear transitive closure: each round grows paths by one edge, # by joining the graph's edges with the already-discovered paths. From fce5e251d636c788cda91345867e0294280c074d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Sep 2014 15:01:11 -0700 Subject: [PATCH 041/315] [SPARK-3491] [MLlib] [PySpark] use pickle to serialize data in MLlib Currently, we serialize the data between JVM and Python case by case manually, this cannot scale to support so many APIs in MLlib. This patch will try to address this problem by serialize the data using pickle protocol, using Pyrolite library to serialize/deserialize in JVM. Pickle protocol can be easily extended to support customized class. All the modules are refactored to use this protocol. Known issues: There will be some performance regression (both CPU and memory, the serialized data increased) Author: Davies Liu Closes #2378 from davies/pickle_mllib and squashes the following commits: dffbba2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into pickle_mllib 810f97f [Davies Liu] fix equal of matrix 032cd62 [Davies Liu] add more type check and conversion for user_product bd738ab [Davies Liu] address comments e431377 [Davies Liu] fix cache of rdd, refactor 19d0967 [Davies Liu] refactor Picklers 2511e76 [Davies Liu] cleanup 1fccf1a [Davies Liu] address comments a2cc855 [Davies Liu] fix tests 9ceff73 [Davies Liu] test size of serialized Rating 44e0551 [Davies Liu] fix cache a379a81 [Davies Liu] fix pickle array in python2.7 df625c7 [Davies Liu] Merge commit '154d141' into pickle_mllib 154d141 [Davies Liu] fix autobatchedpickler 44736d7 [Davies Liu] speed up pickling array in Python 2.7 e1d1bfc [Davies Liu] refactor 708dc02 [Davies Liu] fix tests 9dcfb63 [Davies Liu] fix style 88034f0 [Davies Liu] rafactor, address comments 46a501e [Davies Liu] choose batch size automatically df19464 [Davies Liu] memorize the module and class name during pickleing f3506c5 [Davies Liu] Merge branch 'master' into pickle_mllib 722dd96 [Davies Liu] cleanup _common.py 0ee1525 [Davies Liu] remove outdated tests b02e34f [Davies Liu] remove _common.py 84c721d [Davies Liu] Merge branch 'master' into pickle_mllib 4d7963e [Davies Liu] remove muanlly serialization 6d26b03 [Davies Liu] fix tests c383544 [Davies Liu] classification f2a0856 [Davies Liu] mllib/regression d9f691f [Davies Liu] mllib/util cccb8b1 [Davies Liu] mllib/tree 8fe166a [Davies Liu] Merge branch 'pickle' into pickle_mllib aa2287e [Davies Liu] random f1544c4 [Davies Liu] refactor clustering 52d1350 [Davies Liu] use new protocol in mllib/stat b30ef35 [Davies Liu] use pickle to serialize data for mllib/recommendation f44f771 [Davies Liu] enable tests about array 3908f5c [Davies Liu] Merge branch 'master' into pickle c77c87b [Davies Liu] cleanup debugging code 60e4e2f [Davies Liu] support unpickle array.array for Python 2.6 --- .../apache/spark/api/python/PythonRDD.scala | 31 +- .../apache/spark/api/python/SerDeUtil.scala | 4 +- .../mllib/api/python/PythonMLLibAPI.scala | 487 ++++++--------- .../apache/spark/mllib/linalg/Matrices.scala | 10 +- .../MatrixFactorizationModel.scala | 15 - .../api/python/PythonMLLibAPISuite.scala | 44 +- python/epydoc.conf | 2 +- python/pyspark/context.py | 1 + python/pyspark/mllib/_common.py | 562 ------------------ python/pyspark/mllib/classification.py | 61 +- python/pyspark/mllib/clustering.py | 38 +- python/pyspark/mllib/linalg.py | 256 ++++++-- python/pyspark/mllib/random.py | 54 +- python/pyspark/mllib/recommendation.py | 69 ++- python/pyspark/mllib/regression.py | 105 ++-- python/pyspark/mllib/stat.py | 63 +- python/pyspark/mllib/tests.py | 99 +-- python/pyspark/mllib/tree.py | 167 +++--- python/pyspark/mllib/util.py | 43 +- python/pyspark/rdd.py | 10 +- python/pyspark/serializers.py | 36 ++ python/run-tests | 1 - 22 files changed, 891 insertions(+), 1267 deletions(-) delete mode 100644 python/pyspark/mllib/_common.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b345a8fa7c3..f9ff4ea6ca157 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -775,17 +775,36 @@ private[spark] object PythonRDD extends Logging { }.toJavaRDD() } + private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext(): Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + /** * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark. */ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - pickle.dumps(row) - } - } + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 6668797f5f8be..7903457b17e13 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -68,8 +68,8 @@ private[python] object SerDeUtil extends Logging { construct(args ++ Array("")) } else if (args.length == 2 && args(1).isInstanceOf[String]) { val typecode = args(0).asInstanceOf[String].charAt(0) - val data: String = args(1).asInstanceOf[String] - construct(typecode, machineCodes(typecode), data.getBytes("ISO-8859-1")) + val data: Array[Byte] = args(1).asInstanceOf[String].getBytes("ISO-8859-1") + construct(typecode, machineCodes(typecode), data) } else { super.construct(args) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fa0fa69f38634..9164c294ac7b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -17,16 +17,20 @@ package org.apache.spark.mllib.api.python -import java.nio.{ByteBuffer, ByteOrder} +import java.io.OutputStream import scala.collection.JavaConverters._ +import scala.language.existentials +import scala.reflect.ClassTag + +import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ @@ -40,11 +44,10 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils + /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. - * - * See python/pyspark/mllib/_common.py for the mutually agreed upon data format. */ @DeveloperApi class PythonMLLibAPI extends Serializable { @@ -60,18 +63,17 @@ class PythonMLLibAPI extends Serializable { def loadLabeledPoints( jsc: JavaSparkContext, path: String, - minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint) + minPartitions: Int): JavaRDD[LabeledPoint] = + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) - val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA) - val model = trainFunc(data, initialWeights) + val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] + val model = trainFunc(data.rdd, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.serializeDoubleVector(model.weights)) + ret.add(SerDe.dumps(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -80,7 +82,7 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib LinearRegressionWithSGD.train() */ def trainLinearRegressionModelWithSGD( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], numIterations: Int, stepSize: Double, miniBatchFraction: Double, @@ -106,7 +108,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( (data, initialWeights) => lrAlg.run(data, initialWeights), - dataBytesJRDD, + data, initialWeightsBA) } @@ -114,7 +116,7 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib LassoWithSGD.train() */ def trainLassoModelWithSGD( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], numIterations: Int, stepSize: Double, regParam: Double, @@ -129,7 +131,7 @@ class PythonMLLibAPI extends Serializable { regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, + data, initialWeightsBA) } @@ -137,7 +139,7 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib RidgeRegressionWithSGD.train() */ def trainRidgeModelWithSGD( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], numIterations: Int, stepSize: Double, regParam: Double, @@ -152,7 +154,7 @@ class PythonMLLibAPI extends Serializable { regParam, miniBatchFraction, initialWeights), - dataBytesJRDD, + data, initialWeightsBA) } @@ -160,7 +162,7 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib SVMWithSGD.train() */ def trainSVMModelWithSGD( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], numIterations: Int, stepSize: Double, regParam: Double, @@ -186,7 +188,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( (data, initialWeights) => SVMAlg.run(data, initialWeights), - dataBytesJRDD, + data, initialWeightsBA) } @@ -194,7 +196,7 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib LogisticRegressionWithSGD.train() */ def trainLogisticRegressionModelWithSGD( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], numIterations: Int, stepSize: Double, miniBatchFraction: Double, @@ -220,7 +222,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( (data, initialWeights) => LogRegAlg.run(data, initialWeights), - dataBytesJRDD, + data, initialWeightsBA) } @@ -228,14 +230,13 @@ class PythonMLLibAPI extends Serializable { * Java stub for NaiveBayes.train() */ def trainNaiveBayes( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) - val model = NaiveBayes.train(data, lambda) + val model = NaiveBayes.train(data.rdd, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels))) - ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi))) - ret.add(SerDe.serializeDoubleMatrix(model.theta)) + ret.add(Vectors.dense(model.labels)) + ret.add(Vectors.dense(model.pi)) + ret.add(model.theta) ret } @@ -243,16 +244,12 @@ class PythonMLLibAPI extends Serializable { * Java stub for Python mllib KMeans.train() */ def trainKMeansModel( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[Vector], k: Int, maxIterations: Int, runs: Int, - initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes)) - val model = KMeans.train(data, k, maxIterations, runs, initializationMode) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) - ret + initializationMode: String): KMeansModel = { + KMeans.train(data.rdd, k, maxIterations, runs, initializationMode) } /** @@ -262,13 +259,12 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. */ def trainALSModel( - ratingsBytesJRDD: JavaRDD[Array[Byte]], + ratings: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) - ALS.train(ratings, rank, iterations, lambda, blocks) + ALS.train(ratings.rdd, rank, iterations, lambda, blocks) } /** @@ -278,14 +274,13 @@ class PythonMLLibAPI extends Serializable { * exit; see the Py4J documentation. */ def trainImplicitALSModel( - ratingsBytesJRDD: JavaRDD[Array[Byte]], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) - ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) } /** @@ -293,11 +288,11 @@ class PythonMLLibAPI extends Serializable { * This stub returns a handle to the Java object instead of the content of the Java object. * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. - * @param dataBytesJRDD Training data + * @param data Training data * @param categoricalFeaturesInfoJMap Categorical features info, as Java map */ def trainDecisionTreeModel( - dataBytesJRDD: JavaRDD[Array[Byte]], + data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, categoricalFeaturesInfoJMap: java.util.Map[Int, Int], @@ -307,8 +302,6 @@ class PythonMLLibAPI extends Serializable { minInstancesPerNode: Int, minInfoGain: Double): DecisionTreeModel = { - val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) - val algo = Algo.fromString(algoStr) val impurity = Impurities.fromString(impurityStr) @@ -322,44 +315,15 @@ class PythonMLLibAPI extends Serializable { minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) - DecisionTree.train(data, strategy) - } - - /** - * Predict the label of the given data point. - * This is a Java stub for python DecisionTreeModel.predict() - * - * @param featuresBytes Serialized feature vector for data point - * @return predicted label - */ - def predictDecisionTreeModel( - model: DecisionTreeModel, - featuresBytes: Array[Byte]): Double = { - val features: Vector = SerDe.deserializeDoubleVector(featuresBytes) - model.predict(features) - } - - /** - * Predict the labels of the given data points. - * This is a Java stub for python DecisionTreeModel.predict() - * - * @param dataJRDD A JavaRDD with serialized feature vectors - * @return JavaRDD of serialized predictions - */ - def predictDecisionTreeModel( - model: DecisionTreeModel, - dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes)) - model.predict(data).map(SerDe.serializeDouble) + DecisionTree.train(data.rdd, strategy) } /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. */ - def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = { - val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_))) - new MultivariateStatisticalSummarySerialized(cStats) + def colStats(rdd: JavaRDD[Vector]): MultivariateStatisticalSummary = { + Statistics.colStats(rdd.rdd) } /** @@ -367,19 +331,15 @@ class PythonMLLibAPI extends Serializable { * Returns the correlation matrix serialized into a byte array understood by deserializers in * pyspark. */ - def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { - val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_)) - val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) - SerDe.serializeDoubleMatrix(SerDe.to2dArray(result)) + def corr(x: JavaRDD[Vector], method: String): Matrix = { + Statistics.corr(x.rdd, getCorrNameOrDefault(method)) } /** * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). */ - def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { - val xDeser = x.rdd.map(SerDe.deserializeDouble(_)) - val yDeser = y.rdd.map(SerDe.deserializeDouble(_)) - Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) + def corr(x: JavaRDD[Double], y: JavaRDD[Double], method: String): Double = { + Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } // used by the corr methods to retrieve the name of the correlation method passed in via pyspark @@ -411,10 +371,10 @@ class PythonMLLibAPI extends Serializable { def uniformRDD(jsc: JavaSparkContext, size: Long, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Double] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) + RG.uniformRDD(jsc.sc, size, parts, s) } /** @@ -423,10 +383,10 @@ class PythonMLLibAPI extends Serializable { def normalRDD(jsc: JavaSparkContext, size: Long, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Double] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) + RG.normalRDD(jsc.sc, size, parts, s) } /** @@ -436,10 +396,10 @@ class PythonMLLibAPI extends Serializable { mean: Double, size: Long, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Double] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble) + RG.poissonRDD(jsc.sc, mean, size, parts, s) } /** @@ -449,10 +409,10 @@ class PythonMLLibAPI extends Serializable { numRows: Long, numCols: Int, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Vector] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s) } /** @@ -462,10 +422,10 @@ class PythonMLLibAPI extends Serializable { numRows: Long, numCols: Int, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Vector] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s) } /** @@ -476,259 +436,168 @@ class PythonMLLibAPI extends Serializable { numRows: Long, numCols: Int, numPartitions: java.lang.Integer, - seed: java.lang.Long): JavaRDD[Array[Byte]] = { + seed: java.lang.Long): JavaRDD[Vector] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s) } } /** - * :: DeveloperApi :: - * MultivariateStatisticalSummary with Vector fields serialized. + * SerDe utility functions for PythonMLLibAPI. */ -@DeveloperApi -class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary) - extends Serializable { +private[spark] object SerDe extends Serializable { - def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean) + val PYSPARK_PACKAGE = "pyspark.mllib" - def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance) + /** + * Base class used for pickle + */ + private[python] abstract class BasePickler[T: ClassTag] + extends IObjectPickler with IObjectConstructor { + + private val cls = implicitly[ClassTag[T]].runtimeClass + private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4) + private val name = cls.getSimpleName + + // register this to Pickler and Unpickler + def register(): Unit = { + Pickler.registerCustomPickler(this.getClass, this) + Pickler.registerCustomPickler(cls, this) + Unpickler.registerConstructor(module, name, this) + } - def count: Long = summary.count + def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + if (obj == this) { + out.write(Opcodes.GLOBAL) + out.write((module + "\n" + name + "\n").getBytes()) + } else { + pickler.save(this) // it will be memorized by Pickler + saveState(obj, out, pickler) + out.write(Opcodes.REDUCE) + } + } + + private[python] def saveObjects(out: OutputStream, pickler: Pickler, objects: Any*) = { + if (objects.length == 0 || objects.length > 3) { + out.write(Opcodes.MARK) + } + objects.foreach(pickler.save(_)) + val code = objects.length match { + case 1 => Opcodes.TUPLE1 + case 2 => Opcodes.TUPLE2 + case 3 => Opcodes.TUPLE3 + case _ => Opcodes.TUPLE + } + out.write(code) + } - def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros) + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) + } - def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max) + // Pickler for DenseVector + private[python] class DenseVectorPickler extends BasePickler[DenseVector] { - def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min) -} + def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + val vector: DenseVector = obj.asInstanceOf[DenseVector] + saveObjects(out, pickler, vector.toArray) + } -/** - * SerDe utility functions for PythonMLLibAPI. - */ -private[spark] object SerDe extends Serializable { - private val DENSE_VECTOR_MAGIC: Byte = 1 - private val SPARSE_VECTOR_MAGIC: Byte = 2 - private val DENSE_MATRIX_MAGIC: Byte = 3 - private val LABELED_POINT_MAGIC: Byte = 4 - - private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { - require(bytes.length - offset >= 5, "Byte array too short") - val magic = bytes(offset) - if (magic == DENSE_VECTOR_MAGIC) { - deserializeDenseVector(bytes, offset) - } else if (magic == SPARSE_VECTOR_MAGIC) { - deserializeSparseVector(bytes, offset) - } else { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") + def construct(args: Array[Object]): Object = { + require(args.length == 1) + if (args.length != 1) { + throw new PickleException("should be 1") + } + new DenseVector(args(0).asInstanceOf[Array[Double]]) } } - private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { - require(bytes.length - offset == 8, "Wrong size byte array for Double") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - bb.getDouble - } - - private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 5, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val length = bb.getInt() - require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - Vectors.dense(ans) - } - - private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 9, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val size = bb.getInt() - val nonZeros = bb.getInt() - require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) - val ib = bb.asIntBuffer() - val indices = new Array[Int](nonZeros) - ib.get(indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - val values = new Array[Double](nonZeros) - db.get(values) - Vectors.sparse(size, indices, values) - } + // Pickler for DenseMatrix + private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] { - /** - * Returns an 8-byte array for the input Double. - * - * Note: we currently do not use a magic byte for double for storage efficiency. - * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. - * The corresponding deserializer, deserializeDouble, needs to be modified as well if the - * serialization scheme changes. - */ - private[python] def serializeDouble(double: Double): Array[Byte] = { - val bytes = new Array[Byte](8) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putDouble(double) - bytes - } - - private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](5 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_VECTOR_MAGIC) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - bytes - } - - private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = { - val nonZeros = vector.indices.length - val bytes = new Array[Byte](9 + 12 * nonZeros) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(SPARSE_VECTOR_MAGIC) - bb.putInt(vector.size) - bb.putInt(nonZeros) - val ib = bb.asIntBuffer() - ib.put(vector.indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - db.put(vector.values) - bytes - } - - private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { - case s: SparseVector => - serializeSparseVector(s) - case _ => - serializeDenseVector(vector.toArray) - } - - private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 9) { - throw new IllegalArgumentException("Byte array too short.") + def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] + saveObjects(out, pickler, m.numRows, m.numCols, m.values) } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - if (magic != DENSE_MATRIX_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") + + def construct(args: Array[Object]): Object = { + if (args.length != 3) { + throw new PickleException("should be 3") + } + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], + args(2).asInstanceOf[Array[Double]]) } - val rows = bb.getInt() - val cols = bb.getInt() - if (packetLength != 9 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + + // Pickler for SparseVector + private[python] class SparseVectorPickler extends BasePickler[SparseVector] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + val v: SparseVector = obj.asInstanceOf[SparseVector] + saveObjects(out, pickler, v.size, v.indices, v.values) } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) + + def construct(args: Array[Object]): Object = { + if (args.length != 3) { + throw new PickleException("should be 3") + } + new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]], + args(2).asInstanceOf[Array[Double]]) } - ans } - private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length + // Pickler for LabeledPoint + private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + val point: LabeledPoint = obj.asInstanceOf[LabeledPoint] + saveObjects(out, pickler, point.label, point.features) } - val bytes = new Array[Byte](9 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_MATRIX_MAGIC) - bb.putInt(rows) - bb.putInt(cols) - val db = bb.asDoubleBuffer() - for (i <- 0 until rows) { - db.put(doubles(i)) + + def construct(args: Array[Object]): Object = { + if (args.length != 2) { + throw new PickleException("should be 2") + } + new LabeledPoint(args(0).asInstanceOf[Double], args(1).asInstanceOf[Vector]) } - bytes } - private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { - val fb = serializeDoubleVector(p.features) - val bytes = new Array[Byte](1 + 8 + fb.length) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(LABELED_POINT_MAGIC) - bb.putDouble(p.label) - bb.put(fb) - bytes - } + // Pickler for Rating + private[python] class RatingPickler extends BasePickler[Rating] { - private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { - require(bytes.length >= 9, "Byte array too short") - val magic = bytes(0) - if (magic != LABELED_POINT_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") + def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { + val rating: Rating = obj.asInstanceOf[Rating] + saveObjects(out, pickler, rating.user, rating.product, rating.rating) } - val labelBytes = ByteBuffer.wrap(bytes, 1, 8) - labelBytes.order(ByteOrder.nativeOrder()) - val label = labelBytes.asDoubleBuffer().get(0) - LabeledPoint(label, deserializeDoubleVector(bytes, 9)) - } - // Reformat a Matrix into Array[Array[Double]] for serialization - private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { - val values = matrix.toArray - Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + def construct(args: Array[Object]): Object = { + if (args.length != 3) { + throw new PickleException("should be 3") + } + new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], + args(2).asInstanceOf[Double]) + } } + def initialize(): Unit = { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseVectorPickler().register() + new LabeledPointPickler().register() + new RatingPickler().register() + } - /** Unpack a Rating object from an array of bytes */ - private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - new Rating(user, product, rating) + def dumps(obj: AnyRef): Array[Byte] = { + new Pickler().dumps(obj) } - /** Unpack a tuple of Ints from an array of bytes */ - def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { - val bb = ByteBuffer.wrap(tupleBytes) - bb.order(ByteOrder.nativeOrder()) - val v1 = bb.getInt() - val v2 = bb.getInt() - (v1, v2) + def loads(bytes: Array[Byte]): AnyRef = { + new Unpickler().loads(bytes) } - /** - * Serialize a Rating object into an array of bytes. - * It can be deserialized using RatingDeserializer(). - * - * @param rate the Rating object to serialize - * @return - */ - def serializeRating(rate: Rating): Array[Byte] = { - val len = 3 - val bytes = new Array[Byte](4 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(rate.user.toDouble) - db.put(rate.product.toDouble) - db.put(rate.rating) - bytes + /* convert object into Tuple */ + def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { + rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5711532abcf80..4e87fe088ecc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.linalg +import java.util.Arrays + import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} import org.apache.spark.util.random.XORShiftRandom -import java.util.Arrays - /** * Trait for a local matrix. */ @@ -106,6 +106,12 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) override def toArray: Array[Double] = values + override def equals(o: Any) = o match { + case m: DenseMatrix => + m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray) + case _ => false + } + private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) private[mllib] def apply(i: Int): Double = values(i) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 478c6485052b6..66b58ba770160 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -106,19 +106,4 @@ class MatrixFactorizationModel private[mllib] ( } scored.top(num)(Ordering.by(_._2)) } - - /** - * :: DeveloperApi :: - * Predict the rating of many users for many products. - * This is a Java stub for python predictAll() - * - * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product) - * @return JavaRDD of serialized Rating objects. - */ - @DeveloperApi - def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes)) - predict(usersProducts).map(rate => SerDe.serializeRating(rate)) - } - } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 092d67bbc5238..db8ed62fa46ce 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.mllib.api.python import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{Matrices, Vectors} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.recommendation.Rating class PythonMLLibAPISuite extends FunSuite { - test("vector serialization") { + SerDe.initialize() + + test("pickle vector") { val vectors = Seq( Vectors.dense(Array.empty[Double]), Vectors.dense(0.0), @@ -33,14 +36,13 @@ class PythonMLLibAPISuite extends FunSuite { Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => - val bytes = SerDe.serializeDoubleVector(v) - val u = SerDe.deserializeDoubleVector(bytes) + val u = SerDe.loads(SerDe.dumps(v)) assert(u.getClass === v.getClass) assert(u === v) } } - test("labeled point serialization") { + test("pickle labeled point") { val points = Seq( LabeledPoint(0.0, Vectors.dense(Array.empty[Double])), LabeledPoint(1.0, Vectors.dense(0.0)), @@ -49,34 +51,44 @@ class PythonMLLibAPISuite extends FunSuite { LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => - val bytes = SerDe.serializeLabeledPoint(p) - val q = SerDe.deserializeLabeledPoint(bytes) + val q = SerDe.loads(SerDe.dumps(p)).asInstanceOf[LabeledPoint] assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) } } - test("double serialization") { + test("pickle double") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { - val bytes = SerDe.serializeDouble(x) - val deser = SerDe.deserializeDouble(bytes) + val deser = SerDe.loads(SerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double] // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } } - test("matrix to 2D array") { + test("pickle matrix") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) - val arr = SerDe.to2dArray(matrix) - val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) - assert(arr === expected) + val nm = SerDe.loads(SerDe.dumps(matrix)).asInstanceOf[DenseMatrix] + assert(matrix === nm) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) - val empty2D = SerDe.to2dArray(emptyMatrix) - assert(empty2D === Array[Array[Double]]()) + val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] + assert(emptyMatrix == ne) + } + + test("pickle rating") { + val rat = new Rating(1, 2, 3.0) + val rat2 = SerDe.loads(SerDe.dumps(rat)).asInstanceOf[Rating] + assert(rat == rat2) + + // Test name of class only occur once + val rats = (1 to 10).map(x => new Rating(x, x + 1, x + 3.0)).toArray + val bytes = SerDe.dumps(rats) + assert(bytes.toString.split("Rating").length == 1) + assert(bytes.length / 10 < 25) // 25 bytes per rating + } } diff --git a/python/epydoc.conf b/python/epydoc.conf index 51c0faf359939..8593e08deda19 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -34,5 +34,5 @@ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon pyspark.mllib._common + pyspark.rddsampler pyspark.daemon pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a17f2c1203d36..064a24bff539c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,6 +211,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile SparkContext._jvm.SerDeUtil.initialize() + SparkContext._jvm.SerDe.initialize() if instance: if (SparkContext._active_spark_context and diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py deleted file mode 100644 index 68f6033616726..0000000000000 --- a/python/pyspark/mllib/_common.py +++ /dev/null @@ -1,562 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import struct -import sys -import numpy -from numpy import ndarray, float64, int64, int32, array_equal, array -from pyspark import SparkContext, RDD -from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import FramedSerializer - - -""" -Common utilities shared throughout MLlib, primarily for dealing with -different data types. These include: -- Serialization utilities to / from byte arrays that Java can handle -- Serializers for other data types, like ALS Rating objects -- Common methods for linear models -- Methods to deal with the different vector types we support, such as - SparseVector and scipy.sparse matrices. -""" - - -# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, -# such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. - -_have_scipy = False -_scipy_issparse = None -try: - import scipy.sparse - _have_scipy = True - _scipy_issparse = scipy.sparse.issparse -except: - # No SciPy in environment, but that's okay - pass - - -# Serialization functions to and from Scala. These use the following formats, understood -# by the PythonMLLibAPI class in Scala: -# -# Dense double vector format: -# -# [1-byte 1] [4-byte length] [length*8 bytes of data] -# -# Sparse double vector format: -# -# [1-byte 2] [4-byte length] [4-byte nonzeros] [nonzeros*4 bytes of indices] \ -# [nonzeros*8 bytes of values] -# -# Double matrix format: -# -# [1-byte 3] [4-byte rows] [4-byte cols] [rows*cols*8 bytes of data] -# -# LabeledPoint format: -# -# [1-byte 4] [8-byte label] [dense or sparse vector] -# -# This is all in machine-endian. That means that the Java interpreter and the -# Python interpreter must agree on what endian the machine is. - - -DENSE_VECTOR_MAGIC = 1 -SPARSE_VECTOR_MAGIC = 2 -DENSE_MATRIX_MAGIC = 3 -LABELED_POINT_MAGIC = 4 - - -# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s. -if sys.version_info[:2] <= (2, 6): - def _unpack(fmt, string): - return struct.unpack(fmt, buffer(string)) -else: - _unpack = struct.unpack - - -def _deserialize_numpy_array(shape, ba, offset, dtype=float64): - """ - Deserialize a numpy array of the given type from an offset in - bytearray ba, assigning it the given shape. - - >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0]) - >>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0)) - True - >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2) - >>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0)) - True - >>> x = array([1, 2, 3], dtype=int32) - >>> array_equal(x, _deserialize_numpy_array(x.shape, x.data, 0, dtype=int32)) - True - """ - ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype=dtype, order='C') - return ar.copy() - - -def _serialize_double(d): - """ - Serialize a double (float or numpy.float64) into a mutually understood format. - """ - if type(d) == float or type(d) == float64 or type(d) == int or type(d) == long: - d = float64(d) - ba = bytearray(8) - _copyto(d, buffer=ba, offset=0, shape=[1], dtype=float64) - return ba - else: - raise TypeError("_serialize_double called on non-float input") - - -def _serialize_double_vector(v): - """ - Serialize a double vector into a mutually understood format. - - Note: we currently do not use a magic byte for double for storage - efficiency. This should be reconsidered when we add Ser/De for other - 8-byte types (e.g. Long), for safety. The corresponding deserializer, - _deserialize_double, needs to be modified as well if the serialization - scheme changes. - - >>> x = array([1,2,3]) - >>> y = _deserialize_double_vector(_serialize_double_vector(x)) - >>> array_equal(y, array([1.0, 2.0, 3.0])) - True - """ - v = _convert_vector(v) - if type(v) == ndarray: - return _serialize_dense_vector(v) - elif type(v) == SparseVector: - return _serialize_sparse_vector(v) - else: - raise TypeError("_serialize_double_vector called on a %s; " - "wanted ndarray or SparseVector" % type(v)) - - -def _serialize_dense_vector(v): - """Serialize a dense vector given as a NumPy array.""" - if v.ndim != 1: - raise TypeError("_serialize_double_vector called on a %ddarray; " - "wanted a 1darray" % v.ndim) - if v.dtype != float64: - if numpy.issubdtype(v.dtype, numpy.complex): - raise TypeError("_serialize_double_vector called on an ndarray of %s; " - "wanted ndarray of float64" % v.dtype) - v = v.astype(float64) - length = v.shape[0] - ba = bytearray(5 + 8 * length) - ba[0] = DENSE_VECTOR_MAGIC - length_bytes = ndarray(shape=[1], buffer=ba, offset=1, dtype=int32) - length_bytes[0] = length - _copyto(v, buffer=ba, offset=5, shape=[length], dtype=float64) - return ba - - -def _serialize_sparse_vector(v): - """Serialize a pyspark.mllib.linalg.SparseVector.""" - nonzeros = len(v.indices) - ba = bytearray(9 + 12 * nonzeros) - ba[0] = SPARSE_VECTOR_MAGIC - header = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32) - header[0] = v.size - header[1] = nonzeros - _copyto(v.indices, buffer=ba, offset=9, shape=[nonzeros], dtype=int32) - values_offset = 9 + 4 * nonzeros - _copyto(v.values, buffer=ba, offset=values_offset, shape=[nonzeros], dtype=float64) - return ba - - -def _deserialize_double(ba, offset=0): - """Deserialize a double from a mutually understood format. - - >>> import sys - >>> _deserialize_double(_serialize_double(123.0)) == 123.0 - True - >>> _deserialize_double(_serialize_double(float64(0.0))) == 0.0 - True - >>> _deserialize_double(_serialize_double(1)) == 1.0 - True - >>> _deserialize_double(_serialize_double(1L)) == 1.0 - True - >>> x = sys.float_info.max - >>> _deserialize_double(_serialize_double(sys.float_info.max)) == x - True - >>> y = float64(sys.float_info.max) - >>> _deserialize_double(_serialize_double(sys.float_info.max)) == y - True - """ - if type(ba) != bytearray: - raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) - if len(ba) - offset != 8: - raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) - return _unpack("d", ba[offset:])[0] - - -def _deserialize_double_vector(ba, offset=0): - """Deserialize a double vector from a mutually understood format. - - >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) - >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x))) - True - >>> s = SparseVector(4, [1, 3], [3.0, 5.5]) - >>> s == _deserialize_double_vector(_serialize_double_vector(s)) - True - """ - if type(ba) != bytearray: - raise TypeError("_deserialize_double_vector called on a %s; " - "wanted bytearray" % type(ba)) - nb = len(ba) - offset - if nb < 5: - raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is too short" % nb) - if ba[offset] == DENSE_VECTOR_MAGIC: - return _deserialize_dense_vector(ba, offset) - elif ba[offset] == SPARSE_VECTOR_MAGIC: - return _deserialize_sparse_vector(ba, offset) - else: - raise TypeError("_deserialize_double_vector called on bytearray " - "with wrong magic") - - -def _deserialize_dense_vector(ba, offset=0): - """Deserialize a dense vector into a numpy array.""" - nb = len(ba) - offset - if nb < 5: - raise TypeError("_deserialize_dense_vector called on a %d-byte array, " - "which is too short" % nb) - length = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=int32)[0] - if nb < 8 * length + 5: - raise TypeError("_deserialize_dense_vector called on bytearray " - "with wrong length") - return _deserialize_numpy_array([length], ba, offset + 5) - - -def _deserialize_sparse_vector(ba, offset=0): - """Deserialize a sparse vector into a MLlib SparseVector object.""" - nb = len(ba) - offset - if nb < 9: - raise TypeError("_deserialize_sparse_vector called on a %d-byte array, " - "which is too short" % nb) - header = ndarray(shape=[2], buffer=ba, offset=offset + 1, dtype=int32) - size = header[0] - nonzeros = header[1] - if nb < 9 + 12 * nonzeros: - raise TypeError("_deserialize_sparse_vector called on bytearray " - "with wrong length") - indices = _deserialize_numpy_array([nonzeros], ba, offset + 9, dtype=int32) - values = _deserialize_numpy_array([nonzeros], ba, offset + 9 + 4 * nonzeros, dtype=float64) - return SparseVector(int(size), indices, values) - - -def _serialize_double_matrix(m): - """Serialize a double matrix into a mutually understood format.""" - if (type(m) == ndarray and m.ndim == 2): - if m.dtype != float64: - if numpy.issubdtype(m.dtype, numpy.complex): - raise TypeError("_serialize_double_matrix called on an ndarray of %s; " - "wanted ndarray of float64" % m.dtype) - m = m.astype(float64) - rows = m.shape[0] - cols = m.shape[1] - ba = bytearray(9 + 8 * rows * cols) - ba[0] = DENSE_MATRIX_MAGIC - lengths = ndarray(shape=[3], buffer=ba, offset=1, dtype=int32) - lengths[0] = rows - lengths[1] = cols - _copyto(m, buffer=ba, offset=9, shape=[rows, cols], dtype=float64) - return ba - else: - raise TypeError("_serialize_double_matrix called on a " - "non-double-matrix") - - -def _deserialize_double_matrix(ba): - """Deserialize a double matrix from a mutually understood format.""" - if type(ba) != bytearray: - raise TypeError("_deserialize_double_matrix called on a %s; " - "wanted bytearray" % type(ba)) - if len(ba) < 9: - raise TypeError("_deserialize_double_matrix called on a %d-byte array, " - "which is too short" % len(ba)) - if ba[0] != DENSE_MATRIX_MAGIC: - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong magic") - lengths = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32) - rows = lengths[0] - cols = lengths[1] - if (len(ba) != 8 * rows * cols + 9): - raise TypeError("_deserialize_double_matrix called on bytearray " - "with wrong length") - return _deserialize_numpy_array([rows, cols], ba, 9) - - -def _serialize_labeled_point(p): - """ - Serialize a LabeledPoint with a features vector of any type. - - >>> from pyspark.mllib.regression import LabeledPoint - >>> dp0 = LabeledPoint(0.5, array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])) - >>> dp1 = _deserialize_labeled_point(_serialize_labeled_point(dp0)) - >>> dp1.label == dp0.label - True - >>> array_equal(dp1.features, dp0.features) - True - >>> sp0 = LabeledPoint(0.0, SparseVector(4, [1, 3], [3.0, 5.5])) - >>> sp1 = _deserialize_labeled_point(_serialize_labeled_point(sp0)) - >>> sp1.label == sp1.label - True - >>> sp1.features == sp0.features - True - """ - from pyspark.mllib.regression import LabeledPoint - serialized_features = _serialize_double_vector(p.features) - header = bytearray(9) - header[0] = LABELED_POINT_MAGIC - header_float = ndarray(shape=[1], buffer=header, offset=1, dtype=float64) - header_float[0] = p.label - return header + serialized_features - - -def _deserialize_labeled_point(ba, offset=0): - """Deserialize a LabeledPoint from a mutually understood format.""" - from pyspark.mllib.regression import LabeledPoint - if type(ba) != bytearray: - raise TypeError("Expecting a bytearray but got %s" % type(ba)) - if ba[offset] != LABELED_POINT_MAGIC: - raise TypeError("Expecting magic number %d but got %d" % (LABELED_POINT_MAGIC, ba[0])) - label = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=float64)[0] - features = _deserialize_double_vector(ba, offset + 9) - return LabeledPoint(label, features) - - -def _copyto(array, buffer, offset, shape, dtype): - """ - Copy the contents of a vector to a destination bytearray at the - given offset. - - TODO: In the future this could use numpy.copyto on NumPy 1.7+, but - we should benchmark that to see whether it provides a benefit. - """ - temp_array = ndarray(shape=shape, buffer=buffer, offset=offset, dtype=dtype, order='C') - temp_array[...] = array - - -def _get_unmangled_rdd(data, serializer, cache=True): - """ - :param cache: If True, the serialized RDD is cached. (default = True) - WARNING: Users should unpersist() this later! - """ - dataBytes = data.map(serializer) - dataBytes._bypass_serializer = True - if cache: - dataBytes.cache() - return dataBytes - - -def _get_unmangled_double_vector_rdd(data, cache=True): - """ - Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of - _serialized_double_vectors. - :param cache: If True, the serialized RDD is cached. (default = True) - WARNING: Users should unpersist() this later! - """ - return _get_unmangled_rdd(data, _serialize_double_vector, cache) - - -def _get_unmangled_labeled_point_rdd(data, cache=True): - """ - Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. - :param cache: If True, the serialized RDD is cached. (default = True) - WARNING: Users should unpersist() this later! - """ - return _get_unmangled_rdd(data, _serialize_labeled_point, cache) - - -# Common functions for dealing with and training linear models - -def _linear_predictor_typecheck(x, coeffs): - """ - Check that x is a one-dimensional vector of the right shape. - This is a temporary hackaround until we actually implement bulk predict. - """ - x = _convert_vector(x) - if type(x) == ndarray: - if x.ndim == 1: - if x.shape != coeffs.shape: - raise RuntimeError("Got array of %d elements; wanted %d" % ( - numpy.shape(x)[0], coeffs.shape[0])) - else: - raise RuntimeError("Bulk predict not yet supported.") - elif type(x) == SparseVector: - if x.size != coeffs.shape[0]: - raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( - x.size, coeffs.shape[0])) - elif isinstance(x, RDD): - raise RuntimeError("Bulk predict not yet supported.") - else: - raise TypeError("Argument of type " + type(x).__name__ + " unsupported") - - -# If we weren't given initial weights, take a zero vector of the appropriate -# length. -def _get_initial_weights(initial_weights, data): - if initial_weights is None: - initial_weights = _convert_vector(data.first().features) - if type(initial_weights) == ndarray: - if initial_weights.ndim != 1: - raise TypeError("At least one data element has " - + initial_weights.ndim + " dimensions, which is not 1") - initial_weights = numpy.zeros([initial_weights.shape[0]]) - elif type(initial_weights) == SparseVector: - initial_weights = numpy.zeros([initial_weights.size]) - return initial_weights - - -# train_func should take two parameters, namely data and initial_weights, and -# return the result of a call to the appropriate JVM stub. -# _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): - initial_weights = _get_initial_weights(initial_weights, data) - dataBytes = _get_unmangled_labeled_point_rdd(data) - ans = train_func(dataBytes, _serialize_double_vector(initial_weights)) - if len(ans) != 2: - raise RuntimeError("JVM call result had unexpected length") - elif type(ans[0]) != bytearray: - raise RuntimeError("JVM call result had first element of type " - + type(ans[0]).__name__ + " which is not bytearray") - elif type(ans[1]) != float: - raise RuntimeError("JVM call result had second element of type " - + type(ans[0]).__name__ + " which is not float") - return klass(_deserialize_double_vector(ans[0]), ans[1]) - - -# Functions for serializing ALS Rating objects and tuples - -def _serialize_rating(r): - ba = bytearray(16) - intpart = ndarray(shape=[2], buffer=ba, dtype=int32) - doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) - intpart[0], intpart[1], doublepart[0] = r - return ba - - -class RatingDeserializer(FramedSerializer): - - def loads(self, string): - res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) - return int(res[0]), int(res[1]), res[2] - - def load_stream(self, stream): - while True: - try: - yield self._read_with_length(stream) - except struct.error: - return - except EOFError: - return - - -def _serialize_tuple(t): - ba = bytearray(8) - intpart = ndarray(shape=[2], buffer=ba, dtype=int32) - intpart[0], intpart[1] = t - return ba - - -# Vector math functions that support all of our vector types - -def _convert_vector(vec): - """ - Convert a vector to a format we support internally. This does - the following: - - * For dense NumPy vectors (ndarray), returns them as is - * For our SparseVector class, returns that as is - * For Python lists, converts them to NumPy vectors - * For scipy.sparse.*_matrix column vectors, converts them to - our own SparseVector type. - - This should be called before passing any data to our algorithms - or attempting to serialize it to Java. - """ - if type(vec) == ndarray or type(vec) == SparseVector: - return vec - elif type(vec) == list: - return array(vec, dtype=float64) - elif _have_scipy: - if _scipy_issparse(vec): - assert vec.shape[1] == 1, "Expected column vector" - csc = vec.tocsc() - return SparseVector(vec.shape[0], csc.indices, csc.data) - raise TypeError("Expected NumPy array, SparseVector, or scipy.sparse matrix") - - -def _squared_distance(v1, v2): - """ - Squared distance of two NumPy or sparse vectors. - - >>> dense1 = array([1., 2.]) - >>> sparse1 = SparseVector(2, [0, 1], [1., 2.]) - >>> dense2 = array([2., 1.]) - >>> sparse2 = SparseVector(2, [0, 1], [2., 1.]) - >>> _squared_distance(dense1, dense2) - 2.0 - >>> _squared_distance(dense1, sparse2) - 2.0 - >>> _squared_distance(sparse1, dense2) - 2.0 - >>> _squared_distance(sparse1, sparse2) - 2.0 - """ - v1 = _convert_vector(v1) - v2 = _convert_vector(v2) - if type(v1) == ndarray and type(v2) == ndarray: - diff = v1 - v2 - return numpy.dot(diff, diff) - elif type(v1) == ndarray: - return v2.squared_distance(v1) - else: - return v1.squared_distance(v2) - - -def _dot(vec, target): - """ - Compute the dot product of a vector of the types we support - (Numpy array, list, SparseVector, or SciPy sparse) and a target - NumPy array that is either 1- or 2-dimensional. Equivalent to - calling numpy.dot of the two vectors, but for SciPy ones, we - have to transpose them because they're column vectors. - """ - if type(vec) == ndarray: - return numpy.dot(vec, target) - elif type(vec) == SparseVector: - return vec.dot(target) - elif type(vec) == list: - return numpy.dot(_convert_vector(vec), target) - else: - return vec.transpose().dot(target)[0] - - -def _test(): - import doctest - globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 71ab46b61d7fa..ac142fb49a90c 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -15,19 +15,14 @@ # limitations under the License. # +from math import exp + import numpy +from numpy import array -from numpy import array, shape -from pyspark import SparkContext -from pyspark.mllib._common import \ - _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ - _linear_predictor_typecheck, _get_unmangled_labeled_point_rdd -from pyspark.mllib.linalg import SparseVector -from pyspark.mllib.regression import LabeledPoint, LinearModel -from math import exp, log +from pyspark import SparkContext, PickleSerializer +from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', @@ -67,8 +62,7 @@ class LogisticRegressionModel(LinearModel): """ def predict(self, x): - _linear_predictor_typecheck(x, self._coeff) - margin = _dot(x, self._coeff) + self._intercept + margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) else: @@ -81,7 +75,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType=None, intercept=False): + initialWeights=None, regParam=1.0, regType="none", intercept=False): """ Train a logistic regression model on the given data. @@ -106,11 +100,12 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, are activated or not). """ sc = data.context - if regType is None: - regType = "none" - train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) - return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, + + def train(jdata, i): + return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( + jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept) + + return _regression_train_wrapper(sc, train, LogisticRegressionModel, data, initialWeights) @@ -141,8 +136,7 @@ class SVMModel(LinearModel): """ def predict(self, x): - _linear_predictor_typecheck(x, self._coeff) - margin = _dot(x, self._coeff) + self._intercept + margin = self.weights.dot(x) + self.intercept return 1 if margin >= 0 else 0 @@ -150,7 +144,7 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False): """ Train a support vector machine on the given data. @@ -175,11 +169,12 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, are activated or not). """ sc = data.context - if regType is None: - regType = "none" - train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) - return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights) + + def train(jrdd, i): + return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( + jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) + + return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights) class NaiveBayesModel(object): @@ -220,7 +215,8 @@ def __init__(self, labels, pi, theta): def predict(self, x): """Return the most likely class for a data vector x""" - return self.labels[numpy.argmax(self.pi + _dot(x, self.theta.transpose()))] + x = _convert_to_vector(x) + return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] class NaiveBayes(object): @@ -242,12 +238,9 @@ def train(cls, data, lambda_=1.0): @param lambda_: The smoothing parameter """ sc = data.context - dataBytes = _get_unmangled_labeled_point_rdd(data) - ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_) - return NaiveBayesModel( - _deserialize_double_vector(ans[0]), - _deserialize_double_vector(ans[1]), - _deserialize_double_matrix(ans[2])) + jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) + labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) + return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) def _test(): diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index f3e952a1d842a..12c56022717a5 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,15 +15,9 @@ # limitations under the License. # -from numpy import array, dot -from math import sqrt from pyspark import SparkContext -from pyspark.mllib._common import \ - _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _squared_distance, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper -from pyspark.mllib.linalg import SparseVector +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -32,6 +26,7 @@ class KMeansModel(object): """A clustering model derived from the k-means method. + >>> from numpy import array >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) >>> model = KMeans.train( ... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") @@ -71,8 +66,9 @@ def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 best_distance = float("inf") - for i in range(0, len(self.centers)): - distance = _squared_distance(x, self.centers[i]) + x = _convert_to_vector(x) + for i in xrange(len(self.centers)): + distance = x.squared_distance(self.centers[i]) if distance < best_distance: best = i best_distance = distance @@ -82,19 +78,17 @@ def predict(self, x): class KMeans(object): @classmethod - def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - sc = data.context - dataBytes = _get_unmangled_double_vector_rdd(data) - ans = sc._jvm.PythonMLLibAPI().trainKMeansModel( - dataBytes._jrdd, k, maxIterations, runs, initializationMode) - if len(ans) != 1: - raise RuntimeError("JVM call result had unexpected length") - elif type(ans[0]) != bytearray: - raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray") - matrix = _deserialize_double_matrix(ans[0]) - return KMeansModel([row for row in matrix]) + sc = rdd.context + ser = PickleSerializer() + # cache serialized data to avoid objects over head in JVM + cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() + model = sc._jvm.PythonMLLibAPI().trainKMeansModel( + cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) + bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) + centers = ser.loads(str(bytes)) + return KMeansModel([c.toArray() for c in centers]) def _test(): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index e69051c104e37..0a5dcaac55e46 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -23,14 +23,148 @@ SciPy is available in their environment. """ -import numpy -from numpy import array, array_equal, ndarray, float64, int32 +import sys +import array +import copy_reg +import numpy as np -__all__ = ['SparseVector', 'Vectors'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] -class SparseVector(object): +if sys.version_info[:2] == (2, 7): + # speed up pickling array in Python 2.7 + def fast_pickle_array(ar): + return array.array, (ar.typecode, ar.tostring()) + copy_reg.pickle(array.array, fast_pickle_array) + + +# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, +# such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. + +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy in environment, but that's okay + _have_scipy = False + + +def _convert_to_vector(l): + if isinstance(l, Vector): + return l + elif type(l) in (array.array, np.array, np.ndarray, list, tuple): + return DenseVector(l) + elif _have_scipy and scipy.sparse.issparse(l): + assert l.shape[1] == 1, "Expected column vector" + csc = l.tocsc() + return SparseVector(l.shape[0], csc.indices, csc.data) + else: + raise TypeError("Cannot convert type %s into Vector" % type(l)) + + +class Vector(object): + """ + Abstract class for DenseVector and SparseVector + """ + def toArray(self): + """ + Convert the vector into an numpy.ndarray + :return: numpy.ndarray + """ + raise NotImplementedError + + +class DenseVector(Vector): + def __init__(self, ar): + if not isinstance(ar, array.array): + ar = array.array('d', ar) + self.array = ar + + def __reduce__(self): + return DenseVector, (self.array,) + + def dot(self, other): + """ + Compute the dot product of two Vectors. We support + (Numpy array, list, SparseVector, or SciPy sparse) + and a target NumPy array that is either 1- or 2-dimensional. + Equivalent to calling numpy.dot of the two vectors. + + >>> dense = DenseVector(array.array('d', [1., 2.])) + >>> dense.dot(dense) + 5.0 + >>> dense.dot(SparseVector(2, [0, 1], [2., 1.])) + 4.0 + >>> dense.dot(range(1, 3)) + 5.0 + >>> dense.dot(np.array(range(1, 3))) + 5.0 + """ + if isinstance(other, SparseVector): + return other.dot(self) + elif _have_scipy and scipy.sparse.issparse(other): + return other.transpose().dot(self.toArray())[0] + elif isinstance(other, Vector): + return np.dot(self.toArray(), other.toArray()) + else: + return np.dot(self.toArray(), other) + + def squared_distance(self, other): + """ + Squared distance of two Vectors. + + >>> dense1 = DenseVector(array.array('d', [1., 2.])) + >>> dense1.squared_distance(dense1) + 0.0 + >>> dense2 = np.array([2., 1.]) + >>> dense1.squared_distance(dense2) + 2.0 + >>> dense3 = [2., 1.] + >>> dense1.squared_distance(dense3) + 2.0 + >>> sparse1 = SparseVector(2, [0, 1], [2., 1.]) + >>> dense1.squared_distance(sparse1) + 2.0 + """ + if isinstance(other, SparseVector): + return other.squared_distance(self) + elif _have_scipy and scipy.sparse.issparse(other): + return _convert_to_vector(other).squared_distance(self) + + if isinstance(other, Vector): + other = other.toArray() + elif not isinstance(other, np.ndarray): + other = np.array(other) + diff = self.toArray() - other + return np.dot(diff, diff) + + def toArray(self): + return np.array(self.array) + + def __getitem__(self, item): + return self.array[item] + + def __len__(self): + return len(self.array) + + def __str__(self): + return "[" + ",".join([str(v) for v in self.array]) + "]" + + def __repr__(self): + return "DenseVector(%r)" % self.array + + def __eq__(self, other): + return isinstance(other, DenseVector) and self.array == other.array + + def __ne__(self, other): + return not self == other + + def __getattr__(self, item): + return getattr(self.array, item) + + +class SparseVector(Vector): """ A simple sparse vector class for passing data to MLlib. Users may @@ -61,16 +195,19 @@ def __init__(self, size, *args): if type(pairs) == dict: pairs = pairs.items() pairs = sorted(pairs) - self.indices = array([p[0] for p in pairs], dtype=int32) - self.values = array([p[1] for p in pairs], dtype=float64) + self.indices = array.array('i', [p[0] for p in pairs]) + self.values = array.array('d', [p[1] for p in pairs]) else: assert len(args[0]) == len(args[1]), "index and value arrays not same length" - self.indices = array(args[0], dtype=int32) - self.values = array(args[1], dtype=float64) + self.indices = array.array('i', args[0]) + self.values = array.array('d', args[1]) for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: raise TypeError("indices array must be sorted") + def __reduce__(self): + return (SparseVector, (self.size, self.indices, self.values)) + def dot(self, other): """ Dot product with a SparseVector or 1- or 2-dimensional Numpy array. @@ -78,15 +215,15 @@ def dot(self, other): >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) >>> a.dot(a) 25.0 - >>> a.dot(array([1., 2., 3., 4.])) + >>> a.dot(array.array('d', [1., 2., 3., 4.])) 22.0 >>> b = SparseVector(4, [2, 4], [1.0, 2.0]) >>> a.dot(b) 0.0 - >>> a.dot(array([[1, 1], [2, 2], [3, 3], [4, 4]])) + >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) array([ 22., 22.]) """ - if type(other) == ndarray: + if type(other) == np.ndarray: if other.ndim == 1: result = 0.0 for i in xrange(len(self.indices)): @@ -94,10 +231,17 @@ def dot(self, other): return result elif other.ndim == 2: results = [self.dot(other[:, i]) for i in xrange(other.shape[1])] - return array(results) + return np.array(results) else: raise Exception("Cannot call dot with %d-dimensional array" % other.ndim) - else: + + elif type(other) in (array.array, DenseVector): + result = 0.0 + for i in xrange(len(self.indices)): + result += self.values[i] * other[self.indices[i]] + return result + + elif type(other) is SparseVector: result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): @@ -110,6 +254,8 @@ def dot(self, other): else: j += 1 return result + else: + return self.dot(_convert_to_vector(other)) def squared_distance(self, other): """ @@ -118,7 +264,9 @@ def squared_distance(self, other): >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) >>> a.squared_distance(a) 0.0 - >>> a.squared_distance(array([1., 2., 3., 4.])) + >>> a.squared_distance(array.array('d', [1., 2., 3., 4.])) + 11.0 + >>> a.squared_distance(np.array([1., 2., 3., 4.])) 11.0 >>> b = SparseVector(4, [2, 4], [1.0, 2.0]) >>> a.squared_distance(b) @@ -126,22 +274,22 @@ def squared_distance(self, other): >>> b.squared_distance(a) 30.0 """ - if type(other) == ndarray: - if other.ndim == 1: - result = 0.0 - j = 0 # index into our own array - for i in xrange(other.shape[0]): - if j < len(self.indices) and self.indices[j] == i: - diff = self.values[j] - other[i] - result += diff * diff - j += 1 - else: - result += other[i] * other[i] - return result - else: + if type(other) in (list, array.array, DenseVector, np.array, np.ndarray): + if type(other) is np.array and other.ndim != 1: raise Exception("Cannot call squared_distance with %d-dimensional array" % other.ndim) - else: + result = 0.0 + j = 0 # index into our own array + for i in xrange(len(other)): + if j < len(self.indices) and self.indices[j] == i: + diff = self.values[j] - other[i] + result += diff * diff + j += 1 + else: + result += other[i] * other[i] + return result + + elif type(other) is SparseVector: result = 0.0 i, j = 0, 0 while i < len(self.indices) and j < len(other.indices): @@ -163,16 +311,21 @@ def squared_distance(self, other): result += other.values[j] * other.values[j] j += 1 return result + else: + return self.squared_distance(_convert_to_vector(other)) def toArray(self): """ Returns a copy of this SparseVector as a 1-dimensional NumPy array. """ - arr = numpy.zeros(self.size) + arr = np.zeros((self.size,), dtype=np.float64) for i in xrange(self.indices.size): arr[self.indices[i]] = self.values[i] return arr + def __len__(self): + return self.size + def __str__(self): inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" @@ -198,8 +351,8 @@ def __eq__(self, other): return (isinstance(other, self.__class__) and other.size == self.size - and array_equal(other.indices, self.indices) - and array_equal(other.values, self.values)) + and other.indices == self.indices + and other.values == self.values) def __ne__(self, other): return not self.__eq__(other) @@ -242,9 +395,9 @@ def dense(elements): returns a NumPy array. >>> Vectors.dense([1, 2, 3]) - array([ 1., 2., 3.]) + DenseVector(array('d', [1.0, 2.0, 3.0])) """ - return array(elements, dtype=float64) + return DenseVector(elements) @staticmethod def stringify(vector): @@ -257,10 +410,39 @@ def stringify(vector): >>> Vectors.stringify(Vectors.dense([0.0, 1.0])) '[0.0,1.0]' """ - if type(vector) == SparseVector: - return str(vector) - else: - return "[" + ",".join([str(v) for v in vector]) + "]" + return str(vector) + + +class Matrix(object): + """ the Matrix """ + def __init__(self, nRow, nCol): + self.nRow = nRow + self.nCol = nCol + + def toArray(self): + raise NotImplementedError + + +class DenseMatrix(Matrix): + def __init__(self, nRow, nCol, values): + Matrix.__init__(self, nRow, nCol) + assert len(values) == nRow * nCol + self.values = values + + def __reduce__(self): + return DenseMatrix, (self.nRow, self.nCol, self.values) + + def toArray(self): + """ + Return an numpy.ndarray + + >>> arr = array.array('d', [float(i) for i in range(4)]) + >>> m = DenseMatrix(2, 2, arr) + >>> m.toArray() + array([[ 0., 1.], + [ 2., 3.]]) + """ + return np.ndarray((self.nRow, self.nCol), np.float64, buffer=self.values.tostring()) def _test(): diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index d53c95fd59c25..a787e4dea2c55 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,15 +19,32 @@ Python package for random data generation. """ +from functools import wraps from pyspark.rdd import RDD -from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector -from pyspark.serializers import NoOpSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer __all__ = ['RandomRDDs', ] +def serialize(f): + @wraps(f) + def func(sc, *a, **kw): + jrdd = f(sc, *a, **kw) + return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc, + BatchedSerializer(PickleSerializer(), 1024)) + return func + + +def toArray(f): + @wraps(f) + def func(sc, *a, **kw): + rdd = f(sc, *a, **kw) + return rdd.map(lambda vec: vec.toArray()) + return func + + class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from @@ -35,6 +52,7 @@ class RandomRDDs(object): """ @staticmethod + @serialize def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -56,11 +74,10 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): >>> parts == sc.defaultParallelism True """ - jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) - return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) + return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) @staticmethod + @serialize def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -80,11 +97,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - 1.0) < 0.1 True """ - jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) - return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) + return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) @staticmethod + @serialize def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson @@ -101,11 +117,11 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - sqrt(mean)) < 0.5 True """ - jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) - return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) + return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) @staticmethod + @toArray + @serialize def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -120,12 +136,12 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ - jrdd = sc._jvm.PythonMLLibAPI() \ + return sc._jvm.PythonMLLibAPI() \ .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) - return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod + @toArray + @serialize def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -140,12 +156,12 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - 1.0) < 0.1 True """ - jrdd = sc._jvm.PythonMLLibAPI() \ + return sc._jvm.PythonMLLibAPI() \ .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) - return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod + @toArray + @serialize def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -163,10 +179,8 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - sqrt(mean)) < 0.5 True """ - jrdd = sc._jvm.PythonMLLibAPI() \ + return sc._jvm.PythonMLLibAPI() \ .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) - return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) def _test(): diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 2df23394da6f8..59c1c5ff0ced0 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,17 +16,25 @@ # from pyspark import SparkContext -from pyspark.mllib._common import \ - _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ - _serialize_tuple, RatingDeserializer +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD __all__ = ['MatrixFactorizationModel', 'ALS'] +class Rating(object): + def __init__(self, user, product, rating): + self.user = int(user) + self.product = int(product) + self.rating = float(rating) + + def __reduce__(self): + return Rating, (self.user, self.product, self.rating) + + def __repr__(self): + return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) + + class MatrixFactorizationModel(object): """A matrix factorisation model trained by regularized alternating @@ -39,7 +47,9 @@ class MatrixFactorizationModel(object): >>> model = ALS.trainImplicit(ratings, 1) >>> model.predict(2,2) is not None True + >>> testset = sc.parallelize([(1, 2), (1, 1)]) + >>> model = ALS.train(ratings, 1) >>> model.predictAll(testset).count() == 2 True """ @@ -54,34 +64,61 @@ def __del__(self): def predict(self, user, product): return self._java_model.predict(user, product) - def predictAll(self, usersProducts): - usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple) - return RDD(self._java_model.predict(usersProductsJRDD._jrdd), - self._context, RatingDeserializer()) + def predictAll(self, user_product): + assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" + first = user_product.first() + if isinstance(first, list): + user_product = user_product.map(tuple) + first = tuple(first) + assert type(first) is tuple and len(first) == 2, \ + "user_product should be RDD of (user, product)" + if any(isinstance(x, str) for x in first): + user_product = user_product.map(lambda (u, p): (int(x), int(p))) + first = tuple(map(int, first)) + assert all(type(x) is int for x in first), "user and product in user_product shoul be int" + sc = self._context + tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd()) + jresult = self._java_model.predict(tuplerdd).toJavaRDD() + return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc, + AutoBatchedSerializer(PickleSerializer())) class ALS(object): + @classmethod + def _prepare(cls, ratings): + assert isinstance(ratings, RDD), "ratings should be RDD" + first = ratings.first() + if not isinstance(first, Rating): + if isinstance(first, (tuple, list)): + ratings = ratings.map(lambda x: Rating(*x)) + else: + raise ValueError("rating should be RDD of Rating or tuple/list") + # serialize them by AutoBatchedSerializer before cache to reduce the + # objects overhead in JVM + cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() + return cached._to_java_object_rdd() + @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): sc = ratings.context - ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) - mod = sc._jvm.PythonMLLibAPI().trainALSModel( - ratingBytes._jrdd, rank, iterations, lambda_, blocks) + jrating = cls._prepare(ratings) + mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks) return MatrixFactorizationModel(sc, mod) @classmethod def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): sc = ratings.context - ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + jrating = cls._prepare(ratings) mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel( - ratingBytes._jrdd, rank, iterations, lambda_, blocks, alpha) + jrating, rank, iterations, lambda_, blocks, alpha) return MatrixFactorizationModel(sc, mod) def _test(): import doctest - globs = globals().copy() + import pyspark.mllib.recommendation + globs = pyspark.mllib.recommendation.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index f572dcfb840b6..cbdbc09858013 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,12 +15,12 @@ # limitations under the License. # -from numpy import array, ndarray -from pyspark import SparkContext -from pyspark.mllib._common import _dot, _regression_train_wrapper, \ - _linear_predictor_typecheck, _have_scipy, _scipy_issparse -from pyspark.mllib.linalg import SparseVector, Vectors +import numpy as np +from numpy import array +from pyspark import SparkContext +from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -38,16 +38,16 @@ class LabeledPoint(object): def __init__(self, label, features): self.label = label - if (type(features) == ndarray or type(features) == SparseVector - or (_have_scipy and _scipy_issparse(features))): - self.features = features - elif type(features) == list: - self.features = array(features) - else: - raise TypeError("Expected NumPy array, list, SparseVector, or scipy.sparse matrix") + self.features = _convert_to_vector(features) + + def __reduce__(self): + return (LabeledPoint, (self.label, self.features)) def __str__(self): - return "(" + ",".join((str(self.label), Vectors.stringify(self.features))) + ")" + return "(" + ",".join((str(self.label), str(self.features))) + ")" + + def __repr__(self): + return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" class LinearModel(object): @@ -55,7 +55,7 @@ class LinearModel(object): """A linear model that has a vector of coefficients and an intercept.""" def __init__(self, weights, intercept): - self._coeff = weights + self._coeff = _convert_to_vector(weights) self._intercept = intercept @property @@ -71,18 +71,19 @@ class LinearRegressionModelBase(LinearModel): """A linear regression model. - >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) - >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 + >>> lrmb = LinearRegressionModelBase(np.array([1.0, 2.0]), 0.1) + >>> abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6 True >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True """ def predict(self, x): - """Predict the value of the dependent variable given a vector x""" - """containing values for the independent variables.""" - _linear_predictor_typecheck(x, self._coeff) - return _dot(x, self._coeff) + self._intercept + """ + Predict the value of the dependent variable given a vector x + containing values for the independent variables. + """ + return self.weights.dot(x) + self.intercept class LinearRegressionModel(LinearRegressionModelBase): @@ -96,10 +97,10 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) - >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=np.array([1.0])) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True - >>> abs(lrm.predict(array([1.0])) - 1) < 0.5 + >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True @@ -117,11 +118,27 @@ class LinearRegressionModel(LinearRegressionModelBase): """ +# train_func should take two parameters, namely data and initial_weights, and +# return the result of a call to the appropriate JVM stub. +# _regression_train_wrapper is responsible for setup and error checking. +def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights): + initial_weights = initial_weights or [0.0] * len(data.first().features) + ser = PickleSerializer() + initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights))) + # use AutoBatchedSerializer before cache to reduce the memory + # overhead in JVM + cached = data._reserialize(AutoBatchedSerializer(ser)).cache() + ans = train_func(cached._to_java_object_rdd(), initial_bytes) + assert len(ans) == 2, "JVM call result had unexpected length" + weights = ser.loads(str(ans[0])) + return modelClass(weights, ans[1]) + + class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType=None, intercept=False): + initialWeights=None, regParam=1.0, regType="none", intercept=False): """ Train a linear regression model on the given data. @@ -146,11 +163,12 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, are activated or not). """ sc = data.context - if regType is None: - regType = "none" - train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) - return _regression_train_wrapper(sc, train_f, LinearRegressionModel, data, initialWeights) + + def train(jrdd, i): + return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( + jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) + + return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -166,9 +184,9 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(2.0, [3.0]) ... ] >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) - >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True - >>> abs(lrm.predict(array([1.0])) - 1) < 0.5 + >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True @@ -179,7 +197,7 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) - >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True @@ -193,9 +211,11 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" sc = data.context - train_f = lambda d, i: sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) - return _regression_train_wrapper(sc, train_f, LassoModel, data, initialWeights) + + def train(jrdd, i): + return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD( + jrdd, iterations, step, regParam, miniBatchFraction, i) + return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights) class RidgeRegressionModel(LinearRegressionModelBase): @@ -211,9 +231,9 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(2.0, [3.0]) ... ] >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) - >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True - >>> abs(lrm.predict(array([1.0])) - 1) < 0.5 + >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True @@ -224,7 +244,7 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) - >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True @@ -238,9 +258,12 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" sc = data.context - train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) - return _regression_train_wrapper(sc, train_func, RidgeRegressionModel, data, initialWeights) + + def train(jrdd, i): + return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD( + jrdd, iterations, step, regParam, miniBatchFraction, i) + + return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights) def _test(): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 8c726f171c978..b9de0909a6fb1 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,14 +19,26 @@ Python package for statistical functions in MLlib. """ -from pyspark.mllib._common import \ - _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ - _serialize_double, _deserialize_double_matrix, _deserialize_double_vector +from functools import wraps + +from pyspark import PickleSerializer __all__ = ['MultivariateStatisticalSummary', 'Statistics'] +def serialize(f): + ser = PickleSerializer() + + @wraps(f) + def func(self): + jvec = f(self) + bytes = self._sc._jvm.SerDe.dumps(jvec) + return ser.loads(str(bytes)).toArray() + + return func + + class MultivariateStatisticalSummary(object): """ @@ -44,33 +56,38 @@ def __init__(self, sc, java_summary): def __del__(self): self._sc._gateway.detach(self._java_summary) + @serialize def mean(self): - return _deserialize_double_vector(self._java_summary.mean()) + return self._java_summary.mean() + @serialize def variance(self): - return _deserialize_double_vector(self._java_summary.variance()) + return self._java_summary.variance() def count(self): return self._java_summary.count() + @serialize def numNonzeros(self): - return _deserialize_double_vector(self._java_summary.numNonzeros()) + return self._java_summary.numNonzeros() + @serialize def max(self): - return _deserialize_double_vector(self._java_summary.max()) + return self._java_summary.max() + @serialize def min(self): - return _deserialize_double_vector(self._java_summary.min()) + return self._java_summary.min() class Statistics(object): @staticmethod - def colStats(X): + def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. - >>> from linalg import Vectors + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8])]) @@ -88,9 +105,9 @@ def colStats(X): >>> cStats.min() array([ 2., 0., 0., -2.]) """ - sc = X.ctx - Xser = _get_unmangled_double_vector_rdd(X) - cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + sc = rdd.ctx + jrdd = rdd._to_java_object_rdd() + cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) return MultivariateStatisticalSummary(sc, cStats) @staticmethod @@ -117,7 +134,7 @@ def corr(x, y=None, method=None): >>> from math import isnan >>> isnan(Statistics.corr(x, zeros)) True - >>> from linalg import Vectors + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) >>> pearsonCorr = Statistics.corr(rdd) @@ -144,18 +161,16 @@ def corr(x, y=None, method=None): # check if y is used to specify the method name instead. if type(y) == str: raise TypeError("Use 'method=' to specify method name.") + + jx = x._to_java_object_rdd() if not y: - try: - Xser = _get_unmangled_double_vector_rdd(x) - except TypeError: - raise TypeError("corr called on a single RDD not consisted of Vectors.") - resultMat = sc._jvm.PythonMLLibAPI().corr(Xser._jrdd, method) - return _deserialize_double_matrix(resultMat) + resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) + bytes = sc._jvm.SerDe.dumps(resultMat) + ser = PickleSerializer() + return ser.loads(str(bytes)).toArray() else: - xSer = _get_unmangled_rdd(x, _serialize_double) - ySer = _get_unmangled_rdd(y, _serialize_double) - result = sc._jvm.PythonMLLibAPI().corr(xSer._jrdd, ySer._jrdd, method) - return result + jy = y._to_java_object_rdd() + return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) def _test(): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 8a851bd35c0e8..f72e88ba6e2ba 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -20,6 +20,8 @@ """ import sys +import array as pyarray + from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): @@ -27,9 +29,8 @@ else: import unittest -from pyspark.mllib._common import _convert_vector, _serialize_double_vector, \ - _deserialize_double_vector, _dot, _squared_distance -from pyspark.mllib.linalg import SparseVector +from pyspark.serializers import PickleSerializer +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.tests import PySparkTestCase @@ -42,39 +43,52 @@ # No SciPy, but that's okay, we'll skip those tests pass +ser = PickleSerializer() + + +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) -class VectorTests(unittest.TestCase): + +class VectorTests(PySparkTestCase): + + def _test_serialize(self, v): + jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) def test_serialize(self): - sv = SparseVector(4, {1: 1, 3: 2}) - dv = array([1., 2., 3., 4.]) - lst = [1, 2, 3, 4] - self.assertTrue(sv is _convert_vector(sv)) - self.assertTrue(dv is _convert_vector(dv)) - self.assertTrue(array_equal(dv, _convert_vector(lst))) - self.assertEquals(sv, _deserialize_double_vector(_serialize_double_vector(sv))) - self.assertTrue(array_equal(dv, _deserialize_double_vector(_serialize_double_vector(dv)))) - self.assertTrue(array_equal(dv, _deserialize_double_vector(_serialize_double_vector(lst)))) + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) - dv = array([1., 2., 3., 4.]) - lst = [1, 2, 3, 4] + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) mat = array([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]) - self.assertEquals(10.0, _dot(sv, dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), _dot(sv, mat))) - self.assertEquals(30.0, _dot(dv, dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), _dot(dv, mat))) - self.assertEquals(30.0, _dot(lst, dv)) - self.assertTrue(array_equal(array([10., 20., 30., 40.]), _dot(lst, mat))) + self.assertEquals(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEquals(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEquals(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) - dv = array([1., 2., 3., 4.]) - lst = [4, 3, 2, 1] + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) self.assertEquals(15.0, _squared_distance(sv, dv)) self.assertEquals(25.0, _squared_distance(sv, lst)) self.assertEquals(20.0, _squared_distance(dv, lst)) @@ -198,41 +212,36 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_vector(lil)) - self.assertEquals(sv, _convert_vector(lil.tocsc())) - self.assertEquals(sv, _convert_vector(lil.tocoo())) - self.assertEquals(sv, _convert_vector(lil.tocsr())) - self.assertEquals(sv, _convert_vector(lil.todok())) - self.assertEquals(sv, _deserialize_double_vector(_serialize_double_vector(lil))) - self.assertEquals(sv, _deserialize_double_vector(_serialize_double_vector(lil.tocsc()))) - self.assertEquals(sv, _deserialize_double_vector(_serialize_double_vector(lil.tocsr()))) - self.assertEquals(sv, _deserialize_double_vector(_serialize_double_vector(lil.todok()))) + self.assertEquals(sv, _convert_to_vector(lil)) + self.assertEquals(sv, _convert_to_vector(lil.tocsc())) + self.assertEquals(sv, _convert_to_vector(lil.tocoo())) + self.assertEquals(sv, _convert_to_vector(lil.tocsr())) + self.assertEquals(sv, _convert_to_vector(lil.todok())) + + def serialize(l): + return ser.loads(ser.dumps(_convert_to_vector(l))) + self.assertEquals(sv, serialize(lil)) + self.assertEquals(sv, serialize(lil.tocsc())) + self.assertEquals(sv, serialize(lil.tocsr())) + self.assertEquals(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 - dv = array([1., 2., 3., 4.]) - sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - mat = array([[1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.], - [1., 2., 3., 4.]]) - self.assertEquals(10.0, _dot(lil, dv)) - self.assertTrue(array_equal(array([3., 6., 9., 12.]), _dot(lil, mat))) + dv = DenseVector(array([1., 2., 3., 4.])) + self.assertEquals(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) lil[1, 0] = 3 lil[3, 0] = 2 - dv = array([1., 2., 3., 4.]) + dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, _squared_distance(lil, dv)) - self.assertEquals(15.0, _squared_distance(lil, sv)) - self.assertEquals(15.0, _squared_distance(dv, lil)) - self.assertEquals(15.0, _squared_distance(sv, lil)) + self.assertEquals(15.0, dv.squared_distance(lil)) + self.assertEquals(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5b13ab682bbfc..f59a818a6e74d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -18,13 +18,9 @@ from py4j.java_collections import MapConverter from pyspark import SparkContext, RDD -from pyspark.mllib._common import \ - _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ - _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ - _deserialize_double +from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.mllib.linalg import Vector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint -from pyspark.serializers import NoOpSerializer - __all__ = ['DecisionTreeModel', 'DecisionTree'] @@ -55,21 +51,24 @@ def predict(self, x): :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ - pythonAPI = self._sc._jvm.PythonMLLibAPI() + SerDe = self._sc._jvm.SerDe + ser = PickleSerializer() if isinstance(x, RDD): # Bulk prediction - if x.count() == 0: + first = x.take(1) + if not first: return self._sc.parallelize([]) - dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) - jSerializedPreds = \ - pythonAPI.predictDecisionTreeModel(self._java_model, - dataBytes._jrdd) - serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) - return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) + if not isinstance(first[0], Vector): + x = x.map(_convert_to_vector) + jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD() + jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred) + return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) + else: # Assume x is a single data point. - x_ = _serialize_double_vector(x) - return pythonAPI.predictDecisionTreeModel(self._java_model, x_) + bytes = bytearray(ser.dumps(_convert_to_vector(x))) + vec = self._sc._jvm.SerDe.loads(bytes) + return self._java_model.predict(vec) def numNodes(self): return self._java_model.numNodes() @@ -77,7 +76,7 @@ def numNodes(self): def depth(self): return self._java_model.depth() - def __str__(self): + def __repr__(self): return self._java_model.toString() @@ -90,52 +89,23 @@ class DecisionTree(object): EXPERIMENTAL: This is an experimental API. It will probably be modified for Spark v1.2. - Example usage: - - >>> from numpy import array - >>> import sys - >>> from pyspark.mllib.regression import LabeledPoint - >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector - >>> - >>> data = [ - ... LabeledPoint(0.0, [0.0]), - ... LabeledPoint(1.0, [1.0]), - ... LabeledPoint(1.0, [2.0]), - ... LabeledPoint(1.0, [3.0]) - ... ] - >>> categoricalFeaturesInfo = {} # no categorical features - >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, - ... categoricalFeaturesInfo=categoricalFeaturesInfo) - >>> sys.stdout.write(model) - DecisionTreeModel classifier - If (feature 0 <= 0.5) - Predict: 0.0 - Else (feature 0 > 0.5) - Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True - >>> sparse_data = [ - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), - ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), - ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) - ... ] - >>> - >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), - ... categoricalFeaturesInfo=categoricalFeaturesInfo) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True """ + @staticmethod + def _train(data, type, numClasses, categoricalFeaturesInfo, + impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + sc = data.context + jrdd = data._to_java_object_rdd() + cfiMap = MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + jrdd, type, numClasses, cfiMap, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return DecisionTreeModel(sc, model) + @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, @@ -159,18 +129,34 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel + + Example usage: + + >>> from numpy import array + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) + >>> print model, # it already has newline + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True """ - sc = data.context - dataBytes = _get_unmangled_labeled_point_rdd(data) - categoricalFeaturesInfoJMap = \ - MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, "classification", - numClasses, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - dataBytes.unpersist() - return DecisionTreeModel(sc, model) + return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @staticmethod def trainRegressor(data, categoricalFeaturesInfo, @@ -194,18 +180,33 @@ def trainRegressor(data, categoricalFeaturesInfo, the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel + + Example usage: + + >>> from numpy import array + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True """ - sc = data.context - dataBytes = _get_unmangled_labeled_point_rdd(data) - categoricalFeaturesInfoJMap = \ - MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, "regression", - 0, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - dataBytes.unpersist() - return DecisionTreeModel(sc, model) + return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 1c7b8c809ab5b..8233d4e81f1ca 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,11 +18,10 @@ import numpy as np import warnings -from pyspark.mllib.linalg import Vectors, SparseVector -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib._common import _convert_vector, _deserialize_labeled_point from pyspark.rdd import RDD -from pyspark.serializers import NoOpSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint class MLUtils(object): @@ -32,15 +31,12 @@ class MLUtils(object): """ @staticmethod - def _parse_libsvm_line(line, multiclass): - warnings.warn("deprecated", DeprecationWarning) - return _parse_libsvm_line(line) - - @staticmethod - def _parse_libsvm_line(line): + def _parse_libsvm_line(line, multiclass=None): """ Parses a line in LIBSVM format into (label, indices, values). """ + if multiclass is not None: + warnings.warn("deprecated", DeprecationWarning) items = line.split(None) label = float(items[0]) nnz = len(items) - 1 @@ -55,27 +51,20 @@ def _parse_libsvm_line(line): @staticmethod def _convert_labeled_point_to_libsvm(p): """Converts a LabeledPoint to a string in LIBSVM format.""" + assert isinstance(p, LabeledPoint) items = [str(p.label)] - v = _convert_vector(p.features) - if type(v) == np.ndarray: - for i in xrange(len(v)): - items.append(str(i + 1) + ":" + str(v[i])) - elif type(v) == SparseVector: + v = _convert_to_vector(p.features) + if isinstance(v, SparseVector): nnz = len(v.indices) for i in xrange(nnz): items.append(str(v.indices[i] + 1) + ":" + str(v.values[i])) else: - raise TypeError("_convert_labeled_point_to_libsvm needs either ndarray or SparseVector" - " but got " % type(v)) + for i in xrange(len(v)): + items.append(str(i + 1) + ":" + str(v[i])) return " ".join(items) @staticmethod - def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): - warnings.warn("deprecated", DeprecationWarning) - return loadLibSVMFile(sc, path, numFeatures, minPartitions) - - @staticmethod - def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): + def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of LabeledPoint. The LIBSVM format is a text-based format used by @@ -122,6 +111,8 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> print examples[2] (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ + if multiclass is not None: + warnings.warn("deprecated", DeprecationWarning) lines = sc.textFile(path, minPartitions) parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) @@ -182,9 +173,9 @@ def loadLabeledPoints(sc, path, minPartitions=None): (0.0,[1.01,2.02,3.03]) """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) - jSerialized = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - serialized = RDD(jSerialized, sc, NoOpSerializer()) - return serialized.map(lambda bytes: _deserialize_labeled_point(bytearray(bytes))) + jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) + jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd) + return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer())) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b43606b7304c5..8ef233bc80c5c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -34,7 +34,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, CompressedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -1927,10 +1927,10 @@ def _to_java_object_rdd(self): It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. """ - if not self._is_pickled(): - self = self._reserialize(BatchedSerializer(PickleSerializer(), 1024)) - batched = isinstance(self._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(self._jrdd, batched) + rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ + if not self._is_pickled() else self + is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) + return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) def countApprox(self, timeout, confidence=0.95): """ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 44ac5642836e0..2672da36c1f50 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ import types import collections import zlib +import itertools from pyspark import cloudpickle @@ -214,6 +215,41 @@ def __str__(self): return "BatchedSerializer<%s>" % str(self.serializer) +class AutoBatchedSerializer(BatchedSerializer): + """ + Choose the size of batch automatically based on the size of object + """ + + def __init__(self, serializer, bestSize=1 << 20): + BatchedSerializer.__init__(self, serializer, -1) + self.bestSize = bestSize + + def dump_stream(self, iterator, stream): + batch, best = 1, self.bestSize + iterator = iter(iterator) + while True: + vs = list(itertools.islice(iterator, batch)) + if not vs: + break + + bytes = self.serializer.dumps(vs) + write_int(len(bytes), stream) + stream.write(bytes) + + size = len(bytes) + if size < best: + batch *= 2 + elif size > best * 10 and batch > 1: + batch /= 2 + + def __eq__(self, other): + return (isinstance(other, AutoBatchedSerializer) and + other.serializer == self.serializer) + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) + + class CartesianDeserializer(FramedSerializer): """ diff --git a/python/run-tests b/python/run-tests index a67e5a99fbdcc..a7ec270c7da21 100755 --- a/python/run-tests +++ b/python/run-tests @@ -73,7 +73,6 @@ run_test "pyspark/serializers.py" unset PYSPARK_DOC_TEST run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" -run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/linalg.py" From 2c3cc7641d86fa5196406955325a042890f77563 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 19 Sep 2014 15:29:22 -0700 Subject: [PATCH 042/315] [SPARK-3501] [SQL] Fix the bug of Hive SimpleUDF creates unnecessary type cast When do the query like: ``` select datediff(cast(value as timestamp), cast('2002-03-21 00:00:00' as timestamp)) from src; ``` SparkSQL will raise exception: ``` [info] scala.MatchError: TimestampType (of class org.apache.spark.sql.catalyst.types.TimestampType$) [info] at org.apache.spark.sql.catalyst.expressions.Cast.castToTimestamp(Cast.scala:77) [info] at org.apache.spark.sql.catalyst.expressions.Cast.cast$lzycompute(Cast.scala:251) [info] at org.apache.spark.sql.catalyst.expressions.Cast.cast(Cast.scala:247) [info] at org.apache.spark.sql.catalyst.expressions.Cast.eval(Cast.scala:263) [info] at org.apache.spark.sql.catalyst.optimizer.ConstantFolding$$anonfun$apply$5$$anonfun$applyOrElse$2.applyOrElse(Optimizer.scala:217) [info] at org.apache.spark.sql.catalyst.optimizer.ConstantFolding$$anonfun$apply$5$$anonfun$applyOrElse$2.applyOrElse(Optimizer.scala:210) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:144) [info] at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4$$anonfun$apply$2.apply(TreeNode.scala:180) [info] at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) [info] at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) ``` Author: Cheng Hao Closes #2368 from chenghao-intel/cast_exception and squashes the following commits: 5c9c3a5 [Cheng Hao] make more clear code 49dfc50 [Cheng Hao] Add no-op for Cast and revert the position of SimplifyCasts b804abd [Cheng Hao] Add unit test to show the failure in identical data type casting 330a5c8 [Cheng Hao] Update Code based on comments b834ed4 [Cheng Hao] Fix bug of HiveSimpleUDF with unnecessary type cast which cause exception in constant folding --- .../apache/spark/sql/catalyst/expressions/Cast.scala | 1 + .../scala/org/apache/spark/sql/hive/hiveUdfs.scala | 3 ++- ...imestamp in UDF-0-66952a3949d7544716fd1a675498b1fa | 1 + .../spark/sql/hive/execution/HiveQuerySuite.scala | 11 ++++++++++- 4 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0ad2b30cf9c1f..0379275121bf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -245,6 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] lazy val cast: Any => Any = dataType match { + case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary case DecimalType => castToDecimal diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7d1ad53d8bdb3..7cda0dd302c86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -51,12 +51,13 @@ private[hive] abstract class HiveFunctionRegistry val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) - lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) + val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) HiveSimpleUdf( functionClassName, children.zip(expectedDataTypes).map { case (e, NullType) => e + case (e, t) if (e.dataType == t) => e case (e, t) => Cast(e, t) } ) diff --git a/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa b/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8c8a8b124ac69..56bcd95eab4bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -142,16 +142,25 @@ class HiveQuerySuite extends HiveComparisonTest { setConf("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) setConf("spark.sql.dialect", "hiveql") - } test("Query expressed in HiveQL") { sql("FROM src SELECT key").collect() } + test("Query with constant folding the CAST") { + sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect() + } + createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT", "SELECT AVG(0), SUM(0), COUNT(null), COUNT(value) FROM src GROUP BY key") + createQueryTest("Cast Timestamp to Timestamp in UDF", + """ + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | FROM src LIMIT 1 + """.stripMargin) + createQueryTest("Simple Average", "SELECT AVG(key) FROM src") From 5522151eb14f4208798901f5c090868edd8e8dde Mon Sep 17 00:00:00 2001 From: ravipesala Date: Fri, 19 Sep 2014 15:31:57 -0700 Subject: [PATCH 043/315] [SPARK-2594][SQL] Support CACHE TABLE AS SELECT ... This feature allows user to add cache table from the select query. Example : ```CACHE TABLE testCacheTable AS SELECT * FROM TEST_TABLE``` Spark takes this type of SQL as command and it does lazy caching just like ```SQLContext.cacheTable```, ```CACHE TABLE ``` does. It can be executed from both SQLContext and HiveContext. Recreated the pull request after rebasing with master.And fixed all the comments raised in previous pull requests. https://github.com/apache/spark/pull/2381 https://github.com/apache/spark/pull/2390 Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2397 from ravipesala/SPARK-2594 and squashes the following commits: a5f0beb [ravipesala] Simplified the code as per Admin comment. 8059cd2 [ravipesala] Changed the behaviour from eager caching to lazy caching. d6e469d [ravipesala] Code review comments by Admin are handled. c18aa38 [ravipesala] Merge remote-tracking branch 'remotes/ravipesala/Add-Cache-table-as' into SPARK-2594 394d5ca [ravipesala] Changed style fb1759b [ravipesala] Updated as per Admin comments 8c9993c [ravipesala] Changed the style d8b37b2 [ravipesala] Updated as per the comments by Admin bc0bffc [ravipesala] Merge remote-tracking branch 'ravipesala/Add-Cache-table-as' into Add-Cache-table-as e3265d0 [ravipesala] Updated the code as per the comments by Admin in pull request. 724b9db [ravipesala] Changed style aaf5b59 [ravipesala] Added comment dc33895 [ravipesala] Updated parser to support add cache table command b5276b2 [ravipesala] Updated parser to support add cache table command eebc0c1 [ravipesala] Add CACHE TABLE AS SELECT ... 6758f80 [ravipesala] Changed style 7459ce3 [ravipesala] Added comment 13c8e27 [ravipesala] Updated parser to support add cache table command 4e858d8 [ravipesala] Updated parser to support add cache table command b803fc8 [ravipesala] Add CACHE TABLE AS SELECT ... --- .../apache/spark/sql/catalyst/SqlParser.scala | 14 +++++++-- .../sql/catalyst/plans/logical/commands.scala | 5 ++++ .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../apache/spark/sql/execution/commands.scala | 18 +++++++++++ .../apache/spark/sql/CachedTableSuite.scala | 13 ++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 30 ++++++++++++------- 6 files changed, 69 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ca69531c69a77..862f78702c4e6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -151,7 +151,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache + | insert | cache | unCache ) protected lazy val select: Parser[LogicalPlan] = @@ -183,9 +183,17 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } protected lazy val cache: Parser[LogicalPlan] = - (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ { - case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache) + CACHE ~ TABLE ~> ident ~ opt(AS ~> select) <~ opt(";") ^^ { + case tableName ~ None => + CacheCommand(tableName, true) + case tableName ~ Some(plan) => + CacheTableAsSelectCommand(tableName, plan) } + + protected lazy val unCache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { + case tableName => CacheCommand(tableName, false) + } protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index a01809c1fc5e2..8366639fa0e8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -75,3 +75,8 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "CACHE TABLE tableName AS SELECT .." command. + */ +case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7943d6e1b6fb5..45687d960404c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -305,6 +305,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheCommand(tableName, cache) => Seq(execution.CacheCommand(tableName, cache)(context)) + case logical.CacheTableAsSelectCommand(tableName, plan) => + Seq(execution.CacheTableAsSelectCommand(tableName, plan)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 94543fc95b470..c2f48a902a3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -166,3 +166,21 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( child.output.map(field => Row(field.name, field.dataType.toString, null)) } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CacheTableAsSelectCommand(tableName: String, logicalPlan: LogicalPlan) + extends LeafNode with Command { + + override protected[sql] lazy val sideEffectResult = { + import sqlContext._ + logicalPlan.registerTempTable(tableName) + cacheTable(tableName) + Seq.empty[Row] + } + + override def output: Seq[Attribute] = Seq.empty + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index befef46d93973..591592841e9fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -119,4 +119,17 @@ class CachedTableSuite extends QueryTest { } assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") } + + test("CACHE TABLE tableName AS SELECT Star Table") { + TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestSQLContext.uncacheTable("testCacheTable") + } + + test("'CACHE TABLE tableName AS SELECT ..'") { + TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestSQLContext.uncacheTable("testCacheTable") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 21ecf17028dbc..0aa6292c0184e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -229,7 +229,12 @@ private[hive] object HiveQl { SetCommand(Some(key), Some(value)) } } else if (sql.trim.toLowerCase.startsWith("cache table")) { - CacheCommand(sql.trim.drop(12).trim, true) + sql.trim.drop(12).trim.split(" ").toSeq match { + case Seq(tableName) => + CacheCommand(tableName, true) + case Seq(tableName, _, select @ _*) => + CacheTableAsSelectCommand(tableName, createPlan(select.mkString(" ").trim)) + } } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { @@ -243,15 +248,7 @@ private[hive] object HiveQl { } else if (sql.trim.startsWith("!")) { ShellCommand(sql.drop(1)) } else { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { - NativeCommand(sql) - } else { - nodeToPlan(tree) match { - case NativePlaceholder => NativeCommand(sql) - case other => other - } - } + createPlan(sql) } } catch { case e: Exception => throw new ParseException(sql, e) @@ -262,6 +259,19 @@ private[hive] object HiveQl { """.stripMargin) } } + + /** Creates LogicalPlan for a given HiveQL string. */ + def createPlan(sql: String) = { + val tree = getAst(sql) + if (nativeCommands contains tree.getText) { + NativeCommand(sql) + } else { + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } + } + } def parseDdl(ddl: String): Seq[Attribute] = { val tree = From a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Sep 2014 15:33:42 -0700 Subject: [PATCH 044/315] [SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row Fix the issue when applySchema() to an RDD of Row. Also add type mapping for BinaryType. Author: Davies Liu Closes #2448 from davies/row and squashes the following commits: dd220cf [Davies Liu] fix test 3f3f188 [Davies Liu] add more test f559746 [Davies Liu] add tests, fix serialization 9688fd2 [Davies Liu] support applySchema to RDD of Row --- python/pyspark/sql.py | 13 ++++++++++--- python/pyspark/tests.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42a9920f10e6f..653195ea438cf 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -440,6 +440,7 @@ def _parse_datatype_string(datatype_string): float: DoubleType, str: StringType, unicode: StringType, + bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.datetime: TimestampType, datetime.date: TimestampType, @@ -690,11 +691,12 @@ def _infer_schema_type(obj, dataType): ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (long,), + LongType: (int, long), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), + BinaryType: (bytearray,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -728,9 +730,9 @@ def _verify_type(obj, dataType): return _type = type(dataType) - if _type not in _acceptable_types: - return + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept abject in type %s" % (dataType, type(obj))) @@ -1121,6 +1123,11 @@ def applySchema(self, rdd, schema): # take the first few rows to verify schema rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + for row in rows: _verify_type(row, schema) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7301966e48045..a94eb0f429e0a 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -45,7 +45,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType +from pyspark.sql import SQLContext, IntegerType, Row from pyspark import shuffle _have_scipy = False @@ -659,6 +659,15 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_apply_schema_to_row(self): + srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) + self.assertEqual(srdd.collect(), srdd2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) + self.assertEqual(10, srdd3.count()) + class TestIO(PySparkTestCase): From 3b9cd13ebc108c7c6d518a760333cd992667126c Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 19 Sep 2014 15:34:48 -0700 Subject: [PATCH 045/315] SPARK-3605. Fix typo in SchemaRDD. Author: Sandy Ryza Closes #2460 from sryza/sandy-spark-3605 and squashes the following commits: 09d940b [Sandy Ryza] SPARK-3605. Fix typo in SchemaRDD. --- sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3bc5dce095511..3b873f7c62cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -45,9 +45,8 @@ import org.apache.spark.api.java.JavaRDD * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. * * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using by using the - * `parquetFile` method on [[SQLContext]], and loading JSON datasets - * by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. + * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] + * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. * * == SQL Queries == * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once From ba68a51c407197d478b330403af8fe24a176bef3 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 19 Sep 2014 15:39:31 -0700 Subject: [PATCH 046/315] [SPARK-3485][SQL] Use GenericUDFUtils.ConversionHelper for Simple UDF type conversions This is just another solution to SPARK-3485, in addition to PR #2355 In this patch, we will use ConventionHelper and FunctionRegistry to invoke a simple udf evaluation, which rely more on hive, but much cleaner and safer. We can discuss which one is better. Author: Daoyuan Wang Closes #2407 from adrian-wang/simpleudf and squashes the following commits: 15762d2 [Daoyuan Wang] add posmod test which would fail the test but now ok 0d69eb4 [Daoyuan Wang] another way to pass to hive simple udf --- .../execution/HiveCompatibilitySuite.scala | 1 + .../org/apache/spark/sql/hive/hiveUdfs.scala | 55 ++++++------------- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab487d673e813..556c984ad392b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -801,6 +801,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_or", "udf_parse_url", "udf_PI", + "udf_pmod", "udf_positive", "udf_pow", "udf_power", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7cda0dd302c86..5a0e6c5cc1bba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper + import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.common.`type`.HiveDecimal @@ -105,52 +107,27 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) @transient - lazy val dataType = javaClassToDataType(method.getReturnType) + protected lazy val arguments = children.map(c => toInspector(c.dataType)).toArray - protected lazy val wrappers: Array[(Any) => AnyRef] = method.getParameterTypes.map { argClass => - val primitiveClasses = Seq( - Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE, - classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long], - classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte], - classOf[java.sql.Timestamp] - ) - val matchingConstructor = argClass.getConstructors.find { c => - c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) - } + // Create parameter converters + @transient + protected lazy val conversionHelper = new ConversionHelper(method, arguments) - matchingConstructor match { - case Some(constructor) => - (a: Any) => { - logDebug( - s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} $constructor.") - // We must make sure that primitives get boxed java style. - if (a == null) { - null - } else { - constructor.newInstance(a match { - case i: Int => i: java.lang.Integer - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case other: AnyRef => other - }).asInstanceOf[AnyRef] - } - } - case None => - (a: Any) => a match { - case wrapper => wrap(wrapper) - } - } + @transient + lazy val dataType = javaClassToDataType(method.getReturnType) + + def catalystToHive(value: Any): Object = value match { + // TODO need more types here? or can we use wrap() + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case d => d.asInstanceOf[Object] } // TODO: Finish input output types. override def eval(input: Row): Any = { - val evaluatedChildren = children.map(_.eval(input)) - // Wrap the function arguments in the expected types. - val args = evaluatedChildren.zip(wrappers).map { - case (arg, wrapper) => wrapper(arg) - } + val evaluatedChildren = children.map(c => catalystToHive(c.eval(input))) - // Invoke the udf and unwrap the result. - unwrap(method.invoke(function, args: _*)) + unwrap(FunctionRegistry.invoke(method, function, conversionHelper + .convertIfNecessary(evaluatedChildren: _*): _*)) } } From 99b06b6fd2d79403ef4307ac6f3fa84176e7a622 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Fri, 19 Sep 2014 15:44:47 -0700 Subject: [PATCH 047/315] [Build] Fix passing of args to sbt Simple mistake, simple fix: ```shell args="arg1 arg2 arg3" sbt $args # sbt sees 3 arguments sbt "$args" # sbt sees 1 argument ``` Should fix the problems we are seeing [here](https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/694/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=centos/console), for example. Author: Nicholas Chammas Closes #2462 from nchammas/fix-sbt-master-build and squashes the following commits: 4500c86 [Nicholas Chammas] warn about quoting 10018a6 [Nicholas Chammas] Revert "test hadoop1 build" 7d5356c [Nicholas Chammas] Revert "re-add bad quoting for testing" 061600c [Nicholas Chammas] re-add bad quoting for testing b2de56c [Nicholas Chammas] test hadoop1 build 43fb854 [Nicholas Chammas] unquote profile args --- dev/run-tests | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 5f6df17b509a3..c3d8f49cdd993 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -127,6 +127,8 @@ echo "=========================================================================" # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc #+ to quit or retry. This echo is there to make it not block. + # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a + #+ single argument! # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? echo -e "q\n" \ @@ -159,10 +161,13 @@ echo "=========================================================================" # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc #+ to quit or retry. This echo is there to make it not block. + # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a + #+ single argument! + #+ "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array. # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? echo -e "q\n" \ - | sbt/sbt "$SBT_MAVEN_PROFILES_ARGS" "${SBT_MAVEN_TEST_ARGS[@]}" \ + | sbt/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } From 8af2370619a8a6bb1af7df43b8329ab319348ad8 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 19 Sep 2014 16:02:38 -0700 Subject: [PATCH 048/315] [Docs] Fix outdated docs for standalone cluster This is now supported! Author: andrewor14 Author: Andrew Or Closes #2461 from andrewor14/document-standalone-cluster and squashes the following commits: 85c8b9e [andrewor14] Wording change per Patrick 35e30ee [Andrew Or] Fix outdated docs for standalone cluster --- docs/spark-standalone.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 99a8e43a6b489..29b5491861bf3 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -248,8 +248,10 @@ You can also pass an option `--cores ` to control the number of cores The [`spark-submit` script](submitting-applications.html) provides the most straightforward way to submit a compiled Spark application to the cluster. For standalone clusters, Spark currently -only supports deploying the driver inside the client process that is submitting the application -(`client` deploy mode). +supports two deploy modes. In `client` mode, the driver is launched in the same process as the +client that submits the application. In `cluster` mode, however, the driver is launched from one +of the Worker processes inside the cluster, and the client process exits as soon as it fulfills +its responsibility of submitting the application without waiting for the application to finish. If your application is launched through Spark submit, then the application jar is automatically distributed to all worker nodes. For any additional jars that your application depends on, you From 78d4220fa0bf2f9ee663e34bbf3544a5313b02f0 Mon Sep 17 00:00:00 2001 From: Vida Ha Date: Sat, 20 Sep 2014 01:24:49 -0700 Subject: [PATCH 049/315] SPARK-3608 Break if the instance tag naming succeeds Author: Vida Ha Closes #2466 from vidaha/vida/spark-3608 and squashes the following commits: 9509776 [Vida Ha] Break if the instance tag naming succeeds --- ec2/spark_ec2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index abac71eaca595..fbeccd89b43b3 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -505,6 +505,7 @@ def tag_instance(instance, name): for i in range(0, 5): try: instance.add_tag(key='Name', value=name) + break except: print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) if (i == 5): From c32c8538efca2124924920614e4dbe7ce90938f4 Mon Sep 17 00:00:00 2001 From: "Santiago M. Mola" Date: Sat, 20 Sep 2014 15:05:03 -0700 Subject: [PATCH 050/315] Fix Java example in Streaming Programming Guide "val conf" was used instead of "SparkConf conf" in Java snippet. Author: Santiago M. Mola Closes #2472 from smola/patch-1 and squashes the following commits: 5bfeb9b [Santiago M. Mola] Fix Java example in Streaming Programming Guide --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 41f170580f452..5c21e912ea160 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -138,7 +138,7 @@ import org.apache.spark.streaming.api.java.*; import scala.Tuple2; // Create a local StreamingContext with two working thread and batch interval of 1 second -val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") +SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") JavaStreamingContext jssc = new JavaStreamingContext(conf, new Duration(1000)) {% endhighlight %} From 5f8833c672ab64aa5886a8239ae2ff2a8ea42363 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 20 Sep 2014 15:09:35 -0700 Subject: [PATCH 051/315] [PySpark] remove unnecessary use of numSlices from pyspark tests Author: Matthew Farrellee Closes #2467 from mattf/master-pyspark-remove-numslices-from-tests and squashes the following commits: c49a87b [Matthew Farrellee] [PySpark] remove unnecessary use of numSlices from pyspark tests --- python/pyspark/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a94eb0f429e0a..1b8afb763b26a 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1107,7 +1107,7 @@ def test_reserialization(self): def test_unbatched_save_and_read(self): basepath = self.tempdir.name ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( + self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( basepath + "/unbatched/") unbatched_sequence = sorted(self.sc.sequenceFile( @@ -1153,7 +1153,7 @@ def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]] - rdd = self.sc.parallelize(data, numSlices=len(data)) + rdd = self.sc.parallelize(data, len(data)) self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( basepath + "/malformed/sequence")) From 7c8ad1c0838762f5b632f683834c88a711aef4dd Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sat, 20 Sep 2014 16:03:17 -0700 Subject: [PATCH 052/315] SPARK-3574. Shuffle finish time always reported as -1 The included test waits 100 ms after job completion for task completion events to come in so it can verify they have reasonable finish times. Does anyone know a better way to wait on listener events that are expected to come in? Author: Sandy Ryza Closes #2440 from sryza/sandy-spark-3574 and squashes the following commits: c81439b [Sandy Ryza] Fix test failure b340956 [Sandy Ryza] SPARK-3574. Remove shuffleFinishTime metric --- .../main/scala/org/apache/spark/executor/TaskMetrics.scala | 6 ------ .../main/scala/org/apache/spark/scheduler/JobLogger.scala | 1 - .../src/main/scala/org/apache/spark/util/JsonProtocol.scala | 2 -- .../scala/org/apache/spark/util/JsonProtocolSuite.scala | 3 --- 4 files changed, 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 99a88c13456df..3e49b6235aff3 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -137,7 +137,6 @@ class TaskMetrics extends Serializable { merged.localBlocksFetched += depMetrics.localBlocksFetched merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched merged.remoteBytesRead += depMetrics.remoteBytesRead - merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime) } _shuffleReadMetrics = Some(merged) } @@ -177,11 +176,6 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ @DeveloperApi class ShuffleReadMetrics extends Serializable { - /** - * Absolute time when this task finished reading shuffle data - */ - var shuffleFinishTime: Long = -1 - /** * Number of blocks fetched in this shuffle by this task (remote or local) */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 4d6b5c81883b6..ceb434feb6ca1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -171,7 +171,6 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener } val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { case Some(metrics) => - " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c4dddb2d1037e..6a48f673c4e78 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -255,7 +255,6 @@ private[spark] object JsonProtocol { } def shuffleReadMetricsToJson(shuffleReadMetrics: ShuffleReadMetrics): JValue = { - ("Shuffle Finish Time" -> shuffleReadMetrics.shuffleFinishTime) ~ ("Remote Blocks Fetched" -> shuffleReadMetrics.remoteBlocksFetched) ~ ("Local Blocks Fetched" -> shuffleReadMetrics.localBlocksFetched) ~ ("Fetch Wait Time" -> shuffleReadMetrics.fetchWaitTime) ~ @@ -590,7 +589,6 @@ private[spark] object JsonProtocol { def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = { val metrics = new ShuffleReadMetrics - metrics.shuffleFinishTime = (json \ "Shuffle Finish Time").extract[Long] metrics.remoteBlocksFetched = (json \ "Remote Blocks Fetched").extract[Int] metrics.localBlocksFetched = (json \ "Local Blocks Fetched").extract[Int] metrics.fetchWaitTime = (json \ "Fetch Wait Time").extract[Long] diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 2b45d8b695853..f1f88c5fd3634 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -356,7 +356,6 @@ class JsonProtocolSuite extends FunSuite { } private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { - assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime) assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched) assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched) assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime) @@ -568,7 +567,6 @@ class JsonProtocolSuite extends FunSuite { t.inputMetrics = Some(inputMetrics) } else { val sr = new ShuffleReadMetrics - sr.shuffleFinishTime = b + c sr.remoteBytesRead = b + d sr.localBlocksFetched = e sr.fetchWaitTime = a + d @@ -806,7 +804,6 @@ class JsonProtocolSuite extends FunSuite { | "Memory Bytes Spilled": 800, | "Disk Bytes Spilled": 0, | "Shuffle Read Metrics": { - | "Shuffle Finish Time": 900, | "Remote Blocks Fetched": 800, | "Local Blocks Fetched": 700, | "Fetch Wait Time": 900, From 7f54580c4503d8b6bfcf7d4cbc83b83458140926 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 20 Sep 2014 16:30:49 -0700 Subject: [PATCH 053/315] [SPARK-3609][SQL] Adds sizeInBytes statistics for Limit operator when all output attributes are of native data types This helps to replace shuffled hash joins with broadcast hash joins in some cases. Author: Cheng Lian Closes #2468 from liancheng/more-stats and squashes the following commits: 32687dc [Cheng Lian] Moved the test case to PlannerSuite 5595a91 [Cheng Lian] Removes debugging code 73faf69 [Cheng Lian] Test case for auto choosing broadcast hash join f30fe1d [Cheng Lian] Adds sizeInBytes estimation for Limit when all output types are native types --- .../plans/logical/basicOperators.scala | 11 ++++++++++ .../spark/sql/catalyst/types/dataTypes.scala | 10 ++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++---- .../spark/sql/execution/PlannerSuite.scala | 20 ++++++++++++++++++- 4 files changed, 45 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d10754c7b028..8e8259cae6670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -148,6 +148,17 @@ case class Aggregate( case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output = child.output + + override lazy val statistics: Statistics = + if (output.forall(_.dataType.isInstanceOf[NativeType])) { + val limit = limitExpr.eval(null).asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map { a => + NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType]) + }.sum + Statistics(sizeInBytes = sizeInBytes) + } else { + Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) + } } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 49520b7678e90..e3050e5397937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -122,6 +122,16 @@ object NativeType { IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) + + val defaultSizeOf: Map[NativeType, Int] = Map( + IntegerType -> 4, + BooleanType -> 1, + LongType -> 8, + DoubleType -> 8, + FloatType -> 4, + ShortType -> 2, + ByteType -> 1, + StringType -> 4096) } trait PrimitiveType extends DataType { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 67563b6c55f4b..15f6bcef93886 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone @@ -649,24 +650,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { (3, null) :: (4, 2147483644) :: Nil) } - + test("SPARK-3423 BETWEEN") { checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 5 and 7"), Seq((5, "5"), (6, "6"), (7, "7")) ) - + checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 7 and 7"), Seq((7, "7")) ) - + checkAnswer( sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), Seq() ) } - + test("cast boolean to string") { // TODO Ensure true/false string letter casing is consistent with Hive in all cases. checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 37d64f0de7bab..bfbf431a11913 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ @@ -57,4 +57,22 @@ class PlannerSuite extends FunSuite { val planned = HashAggregation(query) assert(planned.nonEmpty) } + + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { + val origThreshold = autoBroadcastJoinThreshold + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + + // Using a threshold that is definitely larger than the small testing table (b) below + val a = testData.as('a) + val b = testData.limit(3).as('b) + val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + } } From 293ce85145d7a37f7cb329831cbf921be571c2f5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 20 Sep 2014 16:41:14 -0700 Subject: [PATCH 054/315] [SPARK-3414][SQL] Replace LowerCaseSchema with Resolver **This PR introduces a subtle change in semantics for HiveContext when using the results in Python or Scala. Specifically, while resolution remains case insensitive, it is now case preserving.** _This PR is a follow up to #2293 (and to a lesser extent #2262 #2334)._ In #2293 the catalog was changed to store analyzed logical plans instead of unresolved ones. While this change fixed the reported bug (which was caused by yet another instance of us forgetting to put in a `LowerCaseSchema` operator) it had the consequence of breaking assumptions made by `MultiInstanceRelation`. Specifically, we can't replace swap out leaf operators in a tree without rewriting changed expression ids (which happens when you self join the same RDD that has been registered as a temp table). In this PR, I instead remove the need to insert `LowerCaseSchema` operators at all, by moving the concern of matching up identifiers completely into analysis. Doing so allows the test cases from both #2293 and #2262 to pass at the same time (and likely fixes a slew of other "unknown unknown" bugs). While it is rolled back in this PR, storing the analyzed plan might actually be a good idea. For instance, it is kind of confusing if you register a temporary table, change the case sensitivity of resolution and now you can't query that table anymore. This can be addressed in a follow up PR. Follow-ups: - Configurable case sensitivity - Consider storing analyzed plans for temp tables Author: Michael Armbrust Closes #2382 from marmbrus/lowercase and squashes the following commits: c21171e [Michael Armbrust] Ensure the resolver is used for field lookups and ensure that case insensitive resolution is still case preserving. d4320f1 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into lowercase 2de881e [Michael Armbrust] Address comments. 219805a [Michael Armbrust] style 5b93711 [Michael Armbrust] Replace LowerCaseSchema with Resolver. --- .../sql/catalyst/analysis/Analyzer.scala | 38 +++------ .../spark/sql/catalyst/analysis/package.scala | 12 ++- .../sql/catalyst/analysis/unresolved.scala | 6 +- .../expressions/namedExpressions.scala | 10 ++- .../catalyst/plans/logical/LogicalPlan.scala | 77 ++++++++++++++++--- .../plans/logical/basicOperators.scala | 26 ------- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - .../apache/spark/sql/hive/HiveContext.scala | 10 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 +- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../hive/execution/CreateTableAsSelect.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 7 +- ...nsitive-0-98b2e34c9134208e9fe7c62d33010005 | 1 + .../hive/execution/HiveResolutionSuite.scala | 19 +++-- 15 files changed, 125 insertions(+), 99 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 574d96d92942b..71810b798bd04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean) extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { + val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution + // TODO: pass this in as a parameter. val fixedPoint = FixedPoint(100) @@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool lazy val batches: Seq[Batch] = Seq( Batch("MultiInstanceRelations", Once, NewRelationInstances), - Batch("CaseInsensitiveAttributeReferences", Once, - (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*), Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: @@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } - /** - * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase. - */ - object LowercaseAttributeReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case UnresolvedRelation(databaseName, name, alias) => - UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase)) - case Subquery(alias, child) => Subquery(alias.toLowerCase, child) - case q: LogicalPlan => q transformExpressions { - case s: Star => s.copy(table = s.table.map(_.toLowerCase)) - case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase) - case Alias(c, name) => Alias(c, name.toLowerCase)() - case GetField(c, name) => GetField(c, name.toLowerCase) - } - } - } - /** * Replaces [[UnresolvedAttribute]]s with concrete * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's @@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolveChildren(name).getOrElse(u) + val result = q.resolveChildren(name, resolver).getOrElse(u) logDebug(s"Resolving $u to $result") result } @@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) - val resolved = unresolved.flatMap(child.resolveChildren) + val resolved = unresolved.flatMap(child.resolve(_, resolver)) val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a }) val missingInProject = requiredAttributes -- p.output @@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Sort(ordering, Project(projectList ++ missingInProject, child))) } else { + logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => @@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) logDebug(s"Grouping expressions: $groupingRelation") - val resolved = unresolved.flatMap(groupingRelation.resolve) + val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver)) val missingInAggs = resolved.filterNot(a.outputSet.contains) logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") if (missingInAggs.nonEmpty) { @@ -258,14 +242,14 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case p @ Project(projectList, child) if containsStar(projectList) => Project( projectList.flatMap { - case s: Star => s.expand(child.output) + case s: Star => s.expand(child.output, resolver) case o => o :: Nil }, child) case t: ScriptTransformation if containsStar(t.input) => t.copy( input = t.input.flatMap { - case s: Star => s.expand(t.child.output) + case s: Star => s.expand(t.child.output, resolver) case o => o :: Nil } ) @@ -273,7 +257,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case a: Aggregate if containsStar(a.aggregateExpressions) => a.copy( aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output) + case s: Star => s.expand(a.child.output, resolver) case o => o :: Nil } ) @@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool /** * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are * only required to provide scoping information for attributes and can be removed once analysis is - * complete. Similarly, this node also removes - * [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators. + * complete. */ object EliminateAnalysisOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Subquery(_, child) => child - case LowerCaseSchema(child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 9f37ca904ffeb..3f672a3e0fd91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -22,4 +22,14 @@ package org.apache.spark.sql.catalyst * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s * into fully typed objects using information in a schema [[Catalog]]. */ -package object analysis +package object analysis { + + /** + * Responsible for resolving which identifiers refer to the same entity. For example, by using + * case insensitive equality. + */ + type Resolver = (String, String) => Boolean + + val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b) + val caseSensitiveResolution = (a: String, b: String) => a == b +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a2c61c65487cb..67570a6f73c36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -54,6 +54,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this + override def withName(newName: String) = UnresolvedAttribute(name) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -97,13 +98,14 @@ case class Star( override def newInstance = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this + override def withName(newName: String) = this - def expand(input: Seq[Attribute]): Seq[NamedExpression] = { + def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { // If there is no table specified, use all input attributes. case None => input // If there is a table, pick out attributes that are part of this table. - case Some(t) => input.filter(_.qualifiers contains t) + case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) } val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { case (n: NamedExpression, _) => n diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7c4b9d4847e26..59fb0311a9c44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -59,6 +59,7 @@ abstract class Attribute extends NamedExpression { def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute + def withName(newName: String): Attribute def toAttribute = this def newInstance: Attribute @@ -86,7 +87,6 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable - override def toAttribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) @@ -144,6 +144,14 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea } } + override def withName(newName: String): AttributeReference = { + if (name == newName) { + this + } else { + AttributeReference(newName, dataType, nullable)(exprId, qualifiers) + } + } + /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ede431ad4ab27..28d863e58beca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees -abstract class LogicalPlan extends QueryPlan[LogicalPlan] { +abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { self: Product => /** @@ -75,20 +77,25 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolveChildren(name: String): Option[NamedExpression] = - resolve(name, children.flatMap(_.output)) + def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] = + resolve(name, children.flatMap(_.output), resolver) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String): Option[NamedExpression] = - resolve(name, output) + def resolve(name: String, resolver: Resolver): Option[NamedExpression] = + resolve(name, output, resolver) /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { + protected def resolve( + name: String, + input: Seq[Attribute], + resolver: Resolver): Option[NamedExpression] = { + val parts = name.split("\\.") + // Collect all attributes that are output by this nodes children where either the first part // matches the name or where the first part matches the scope and the second part matches the // name. Return these matches along with any remaining parts, which represent dotted access to @@ -96,21 +103,69 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. val remainingParts = - if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts - if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil + if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) { + parts.drop(1) + } else { + parts + } + + if (resolver(option.name, remainingParts.head)) { + // Preserve the case of the user's attribute reference. + (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil + } else { + Nil + } } options.distinct match { - case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. + // One match, no nested fields, use it. + case Seq((a, Nil)) => Some(a) + // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => - Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) - case Seq() => None // No matches. + val aliased = + Alias( + resolveNesting(nestedFields, a, resolver), + nestedFields.last)() // Preserve the case of the user's field access. + Some(aliased) + + // No matches. + case Seq() => + logTrace(s"Could not find $name in ${input.mkString(", ")}") + None + + // More than one match. case ambiguousReferences => throw new TreeNodeException( this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") } } + + /** + * Given a list of successive nested field accesses, and a based expression, attempt to resolve + * the actual field lookups on this expression. + */ + private def resolveNesting( + nestedFields: List[String], + expression: Expression, + resolver: Resolver): Expression = { + + (nestedFields, expression.dataType) match { + case (Nil, _) => expression + case (requestedField :: rest, StructType(fields)) => + val actualField = fields.filter(f => resolver(f.name, requestedField)) + actualField match { + case Seq() => + sys.error( + s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}") + case Seq(singleMatch) => + resolveNesting(rest, GetField(expression, singleMatch.name), resolver) + case multipleMatches => + sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}") + } + case (_, dt) => sys.error(s"Can't access nested field in type $dt") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8e8259cae6670..391508279bb80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -165,32 +165,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output = child.output.map(_.withQualifiers(alias :: Nil)) } -/** - * Converts the schema of `child` to all lowercase, together with LowercaseAttributeReferences - * this allows for optional case insensitive attribute resolution. This node can be elided after - * analysis. - */ -case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { - protected def lowerCaseSchema(dataType: DataType): DataType = dataType match { - case StructType(fields) => - StructType(fields.map(f => - StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) - case otherType => otherType - } - - override val output = child.output.map { - case a: AttributeReference => - AttributeReference( - a.name.toLowerCase, - lowerCaseSchema(a.dataType), - a.nullable)( - a.exprId, - a.qualifiers) - case other => other - } -} - case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7dbaf7faff0c0..b245e1a863cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(None, tableName, rdd.queryExecution.analyzed) + catalog.registerTable(None, tableName, rdd.queryExecution.logical) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 15f6bcef93886..08376eb5e5c4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -381,7 +381,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3349 partitioning after limit") { - /* sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) .registerTempTable("subset1") @@ -396,7 +395,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), (1, "a", 1) :: (2, "b", 2) :: Nil) - */ } test("mixed-case keywords") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index e0be09e6793ea..3e1a7b71528e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { - override def lookupRelation( - databaseName: Option[String], - tableName: String, - alias: Option[String] = None): LogicalPlan = { - - LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias)) - } - } + override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog // Note that HiveUDFs will be overridden by functions registered in this context. @transient diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2c0db9be57e54..6b4399e852c7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -129,14 +129,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable( - LowerCaseSchema(table: MetastoreRelation), _, child, _) => + case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - LowerCaseSchema( InMemoryRelation(_, _, _, - HiveTableScan(_, table, _))), _, child, _) => + HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 43dd3d234f73a..8ac17f37201a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} @@ -55,7 +55,7 @@ private[hive] trait HiveStrategies { object ParquetConversion extends Strategy { implicit class LogicalPlanHacks(s: SchemaRDD) { def lowerCase = - new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan)) + new SchemaRDD(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = new SchemaRDD( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 71ea774d77795..1017fe6d5396d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -21,7 +21,6 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LowerCaseSchema import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.MetastoreRelation @@ -52,8 +51,7 @@ case class CreateTableAsSelect( sc.catalog.createTable(database, tableName, query.output, false) // Get the Metastore Relation sc.catalog.lookupRelation(Some(database), tableName, None) match { - case LowerCaseSchema(r: MetastoreRelation) => r - case o: MetastoreRelation => o + case r: MetastoreRelation => r } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 5a0e6c5cc1bba..19ff3b66ad7ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -44,10 +44,11 @@ private[hive] abstract class HiveFunctionRegistry def lookupFunction(name: String, children: Seq[Expression]): Expression = { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. - val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( - sys.error(s"Couldn't find function $name")) + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $name")) - val functionClassName = functionInfo.getFunctionClass.getName() + val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] diff --git a/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 b/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/database.table table.attr case insensitive-0-98b2e34c9134208e9fe7c62d33010005 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b6be6bc1bfefe..ee9d08ff75450 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -36,6 +36,9 @@ class HiveResolutionSuite extends HiveComparisonTest { createQueryTest("database.table table.attr", "SELECT src.key FROM default.src ORDER BY key LIMIT 1") + createQueryTest("database.table table.attr case insensitive", + "SELECT SRC.Key FROM Default.Src ORDER BY key LIMIT 1") + createQueryTest("alias.attr", "SELECT a.key FROM src a ORDER BY key LIMIT 1") @@ -56,14 +59,18 @@ class HiveResolutionSuite extends HiveComparisonTest { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) .registerTempTable("caseSensitivityTest") - sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") - - println(sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").queryExecution) - - sql("SELECT * FROM casesensitivitytest one JOIN casesensitivitytest two ON one.a = two.a").collect() + val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), + "The output schema did not preserve the case of the query.") + query.collect() + } - // TODO: sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a") + ignore("case insensitivity with scala reflection joins") { + // Test resolution with Scala Reflection + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .registerTempTable("caseSensitivityTest") + sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { From 8e875d2aff5f30a5f7a4bf694fc89a8b852fdcdc Mon Sep 17 00:00:00 2001 From: WangTao Date: Sat, 20 Sep 2014 19:07:07 -0700 Subject: [PATCH 055/315] [SPARK-3599]Avoid loading properties file frequently https://issues.apache.org/jira/browse/SPARK-3599 Author: WangTao Author: WangTaoTheTonic Closes #2454 from WangTaoTheTonic/avoidLoadingFrequently and squashes the following commits: 3681182 [WangTao] do not use clone 7dca036 [WangTao] use lazy val instead 2a79f26 [WangTaoTheTonic] Avoid loaing properties file frequently --- .../org/apache/spark/deploy/SparkSubmit.scala | 2 +- .../spark/deploy/SparkSubmitArguments.scala | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index d132ecb3f9989..580a439c9a892 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -280,7 +280,7 @@ object SparkSubmit { } // Read from default spark properties, if any - for ((k, v) <- args.getDefaultSparkProperties) { + for ((k, v) <- args.defaultSparkProperties) { sysProps.getOrElseUpdate(k, v) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index d545f58c5da7e..92e0917743ed1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -57,12 +57,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { var pyFiles: String = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() - parseOpts(args.toList) - mergeSparkProperties() - checkRequiredArguments() - - /** Return default present in the currently defined defaults file. */ - def getDefaultSparkProperties = { + /** Default properties present in the currently defined defaults file. */ + lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => @@ -79,6 +75,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } + parseOpts(args.toList) + mergeSparkProperties() + checkRequiredArguments() + /** * Fill in any undefined values based on the default properties file or options passed in through * the '--conf' flag. @@ -107,7 +107,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } } - val properties = getDefaultSparkProperties + val properties = HashMap[String, String]() + properties.putAll(defaultSparkProperties) properties.putAll(sparkProperties) // Use properties file as fallback for values which have a direct analog to @@ -213,7 +214,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | verbose $verbose | |Default properties from $propertiesFile: - |${getDefaultSparkProperties.mkString(" ", "\n ", "\n")} + |${defaultSparkProperties.mkString(" ", "\n ", "\n")} """.stripMargin } From d112a6c79dee7b5d8459696f97d329190e8d09a5 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 20 Sep 2014 23:11:05 -0700 Subject: [PATCH 056/315] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1328 (close requested by 'pwendell') Closes #2314 (close requested by 'pwendell') Closes #997 (close requested by 'pwendell') Closes #550 (close requested by 'pwendell') Closes #1506 (close requested by 'pwendell') Closes #2423 (close requested by 'mengxr') Closes #554 (close requested by 'joshrosen') From a0454efe21e5c7ffe1b9bb7b18021a5580952e69 Mon Sep 17 00:00:00 2001 From: Ian Hummel Date: Sun, 21 Sep 2014 13:04:36 -0700 Subject: [PATCH 057/315] [SPARK-3595] Respect configured OutputCommitters when calling saveAsHadoopFile Addresses the issue in https://issues.apache.org/jira/browse/SPARK-3595, namely saveAsHadoopFile hardcoding the OutputCommitter. This is not ideal when running Spark jobs that write to S3, especially when running them from an EMR cluster where the default OutputCommitter is a DirectOutputCommitter. Author: Ian Hummel Closes #2450 from themodernlife/spark-3595 and squashes the following commits: f37a0e5 [Ian Hummel] Update based on comments from pwendell a11d9f3 [Ian Hummel] Fix formatting 4359664 [Ian Hummel] Add an example showing usage 8b6be94 [Ian Hummel] Add ability to specify OutputCommitter, espcially useful when writing to an S3 bucket from an EMR cluster --- .../org/apache/spark/SparkHadoopWriter.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 7 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 107 ++++++++++++++---- 3 files changed, 91 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f6703986bdf11..376e69cd997d5 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -116,7 +116,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } } } else { - logWarning ("No need to commit output of task: " + taID.value) + logInfo ("No need to commit output of task: " + taID.value) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index f6d9d12fe9006..51ba8c2d17834 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -872,7 +872,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } - hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) + + // Use configured output committer if already set + if (conf.getOutputCommitter == null) { + hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 63d3ddb4af98a..e84cc69592339 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.rdd -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util.Progressable + +import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random -import org.scalatest.FunSuite import com.google.common.io.Files -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.conf.{Configuration, Configurable} - -import org.apache.spark.SparkContext._ +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, +OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, +TaskAttemptContext => NewTaskAttempContext} import org.apache.spark.{Partitioner, SharedSparkContext} +import org.apache.spark.SparkContext._ +import org.scalatest.FunSuite class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { test("aggregateByKey") { @@ -467,7 +471,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) // No error, non-configurable formats still work - pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored") + pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") /* Check that configurable formats get configured: @@ -478,6 +482,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } + test("saveAsHadoopFile should respect configured output committers") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(1)))) + val conf = new JobConf() + conf.setOutputCommitter(classOf[FakeOutputCommitter]) + + FakeOutputCommitter.ran = false + pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf) + + assert(FakeOutputCommitter.ran, "OutputCommitter was never called") + } + test("lookup") { val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) @@ -621,40 +636,86 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile tries to instantiate them with Class.newInstance. */ + +/* + * Original Hadoop API + */ class FakeWriter extends RecordWriter[Integer, Integer] { + override def write(key: Integer, value: Integer): Unit = () - def close(p1: TaskAttemptContext) = () + override def close(reporter: Reporter): Unit = () +} + +class FakeOutputCommitter() extends OutputCommitter() { + override def setupJob(jobContext: JobContext): Unit = () + + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + + override def setupTask(taskContext: TaskAttemptContext): Unit = () + + override def commitTask(taskContext: TaskAttemptContext): Unit = { + FakeOutputCommitter.ran = true + () + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = () +} + +/* + * Used to communicate state between the test harness and the OutputCommitter. + */ +object FakeOutputCommitter { + var ran = false +} + +class FakeOutputFormat() extends OutputFormat[Integer, Integer]() { + override def getRecordWriter( + ignored: FileSystem, + job: JobConf, name: String, + progress: Progressable): RecordWriter[Integer, Integer] = { + new FakeWriter() + } + + override def checkOutputSpecs(ignored: FileSystem, job: JobConf): Unit = () +} + +/* + * New-style Hadoop API + */ +class NewFakeWriter extends NewRecordWriter[Integer, Integer] { + + def close(p1: NewTaskAttempContext) = () def write(p1: Integer, p2: Integer) = () } -class FakeCommitter extends OutputCommitter { - def setupJob(p1: JobContext) = () +class NewFakeCommitter extends NewOutputCommitter { + def setupJob(p1: NewJobContext) = () - def needsTaskCommit(p1: TaskAttemptContext): Boolean = false + def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false - def setupTask(p1: TaskAttemptContext) = () + def setupTask(p1: NewTaskAttempContext) = () - def commitTask(p1: TaskAttemptContext) = () + def commitTask(p1: NewTaskAttempContext) = () - def abortTask(p1: TaskAttemptContext) = () + def abortTask(p1: NewTaskAttempContext) = () } -class FakeFormat() extends OutputFormat[Integer, Integer]() { +class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() { - def checkOutputSpecs(p1: JobContext) = () + def checkOutputSpecs(p1: NewJobContext) = () - def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = { - new FakeWriter() + def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { + new NewFakeWriter() } - def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = { - new FakeCommitter() + def getOutputCommitter(p1: NewTaskAttempContext): NewOutputCommitter = { + new NewFakeCommitter() } } -class ConfigTestFormat() extends FakeFormat() with Configurable { +class ConfigTestFormat() extends NewFakeFormat() with Configurable { var setConfCalled = false def setConf(p1: Configuration) = { @@ -664,7 +725,7 @@ class ConfigTestFormat() extends FakeFormat() with Configurable { def getConf: Configuration = null - override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = { + override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = { assert(setConfCalled, "setConf was never called") super.getRecordWriter(p1) } From fd0b32c520e3d1088b2fe9228be114932e6c3a0c Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 21 Sep 2014 13:09:36 -0700 Subject: [PATCH 058/315] [Minor]ignore .idea_modules ignore .idea_modules , ```sbt/sbt gen-idea``` generate this dir. Author: wangfei Closes #2476 from scwf/patch-4 and squashes the following commits: e6ab88a [wangfei] ignore .idea_modules --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1bcd0165761ac..7779980b74a22 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ *.iml *.iws .idea/ +.idea_modules/ sbt/*.jar .settings .cache From fec921552ffccc36937214406b3e4a050eb0d8e0 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Mon, 22 Sep 2014 09:10:41 -0700 Subject: [PATCH 059/315] [MLLib] Fix example code variable name misspelling in MLLib Feature Extraction guide Author: RJ Nowling Closes #2459 from rnowling/tfidf-fix and squashes the following commits: b370a91 [RJ Nowling] Fix variable name misspelling in MLLib Feature Extraction guide --- docs/mllib-feature-extraction.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 44f0f76220b6e..41a27f6208d1b 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -68,7 +68,7 @@ val sc: SparkContext = ... val documents: RDD[Seq[String]] = sc.textFile("...").map(_.split(" ").toSeq) val hashingTF = new HashingTF() -val tf: RDD[Vector] = hasingTF.transform(documents) +val tf: RDD[Vector] = hashingTF.transform(documents) {% endhighlight %} While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: From 56dae30ca70489a62686cb245728b09b2179bb5a Mon Sep 17 00:00:00 2001 From: Grega Kespret Date: Mon, 22 Sep 2014 10:13:44 -0700 Subject: [PATCH 060/315] Update docs to use jsonRDD instead of wrong jsonRdd. Author: Grega Kespret Closes #2479 from gregakespret/patch-1 and squashes the following commits: dd6b90a [Grega Kespret] Update docs to use jsonRDD instead of wrong jsonRdd. --- docs/sql-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5212e19c41349..c1f80544bf0af 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -605,7 +605,7 @@ Spark SQL can automatically infer the schema of a JSON dataset and load it as a This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. {% highlight scala %} // sc is an existing SparkContext. @@ -643,7 +643,7 @@ Spark SQL can automatically infer the schema of a JSON dataset and load it as a This conversion can be done using one of two methods in a JavaSQLContext : * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. {% highlight java %} // sc is an existing JavaSparkContext. @@ -681,7 +681,7 @@ Spark SQL can automatically infer the schema of a JSON dataset and load it as a This conversion can be done using one of two methods in a SQLContext: * `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRdd` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. {% highlight python %} # sc is an existing SparkContext. From f9d6220c792b779be385f3022d146911a22c2130 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 22 Sep 2014 13:47:43 -0700 Subject: [PATCH 061/315] [SPARK-3578] Fix upper bound in GraphGenerators.sampleLogNormal GraphGenerators.sampleLogNormal is supposed to return an integer strictly less than maxVal. However, it violates this guarantee. It generates its return value as follows: ```scala var X: Double = maxVal while (X >= maxVal) { val Z = rand.nextGaussian() X = math.exp(mu + sigma*Z) } math.round(X.toFloat) ``` When X is sampled to be close to (but less than) maxVal, then it will pass the while loop condition, but the rounded result will be equal to maxVal, which will violate the guarantee. For example, if maxVal is 5 and X is 4.9, then X < maxVal, but `math.round(X.toFloat)` is 5. This PR instead rounds X before checking the loop condition, guaranteeing that the condition will hold for the return value. Author: Ankur Dave Closes #2439 from ankurdave/SPARK-3578 and squashes the following commits: f6655e5 [Ankur Dave] Go back to math.floor 5900c22 [Ankur Dave] Round X in loop condition 6fd5fb1 [Ankur Dave] Run sampleLogNormal bounds check 1000 times 1638598 [Ankur Dave] Round down in sampleLogNormal to guarantee upper bound --- .../org/apache/spark/graphx/util/GraphGenerators.scala | 2 +- .../apache/spark/graphx/util/GraphGeneratorsSuite.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index b8309289fe475..8a13c74221546 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -118,7 +118,7 @@ object GraphGenerators { val Z = rand.nextGaussian() X = math.exp(mu + sigma*Z) } - math.round(X.toFloat) + math.floor(X).toInt } /** diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index b346d4db2ef96..3abefbe52fa8a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -64,8 +64,11 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { val sigma = 1.3 val maxVal = 100 - val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) - assert(dstId < maxVal) + val trials = 1000 + for (i <- 1 to trials) { + val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) + assert(dstId < maxVal) + } val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) From 14f8c340402366cb998c563b3f7d9ff7d9940271 Mon Sep 17 00:00:00 2001 From: "peng.zhang" Date: Tue, 23 Sep 2014 08:45:56 -0500 Subject: [PATCH 062/315] [YARN] SPARK-2668: Add variable of yarn log directory for reference from the log4j configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assign value of yarn container log directory to java opts "spark.yarn.app.container.log.dir", So user defined log4j.properties can reference this value and write log to YARN container's log directory. Otherwise, user defined file appender will only write to container's CWD, and log files in CWD will not be displayed on YARN UI,and either cannot be aggregated to HDFS log directory after job finished. User defined log4j.properties reference example: log4j.appender.rolling_file.File = ${spark.yarn.app.container.log.dir}/spark.log Author: peng.zhang Closes #1573 from renozhang/yarn-log-dir and squashes the following commits: 16c5cb8 [peng.zhang] Update doc f2b5e2a [peng.zhang] Change variable's name, and update running-on-yarn.md 503ea2d [peng.zhang] Support log4j log to yarn container dir --- docs/running-on-yarn.md | 2 ++ .../main/scala/org/apache/spark/deploy/yarn/ClientBase.scala | 3 +++ .../org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala | 3 +++ 3 files changed, 8 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 74bcc2eeb65f6..4b3a49eca7007 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -205,6 +205,8 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use "${spark.yarn.app.container.log.dir}" in your log4j.properties. For example, log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log. For streaming application, configuring RollingFileAppender and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + # Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index c96f731923d22..6ae4d496220a5 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -388,6 +388,9 @@ trait ClientBase extends Logging { .foreach(p => javaOpts += s"-Djava.library.path=$p") } + // For log4j configuration to reference + javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + val userClass = if (args.userClass != null) { Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass)) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 312d82a649792..f56f72cafe50e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -98,6 +98,9 @@ trait ExecutorRunnableUtil extends Logging { } */ + // For log4j configuration to reference + javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. From c4022dd52b4827323ff956632dc7623f546da937 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 23 Sep 2014 11:20:52 -0500 Subject: [PATCH 063/315] [SPARK-3477] Clean up code in Yarn Client / ClientBase This is part of a broader effort to clean up the Yarn integration code after #2020. The high-level changes in this PR include: - Removing duplicate code, especially across the alpha and stable APIs - Simplify unnecessarily complex method signatures and hierarchies - Rename unclear variable and method names - Organize logging output produced when the user runs Spark on Yarn - Extensively add documentation - Privatize classes where possible I have tested the stable API on a Hadoop 2.4 cluster. I tested submitting a jar that references classes in other jars in both client and cluster mode. I also made changes in the alpha API, though I do not have access to an alpha cluster. I have verified that it compiles, but it would be ideal if others can help test it. For those interested in some examples in detail, please read on. -------------------------------------------------------------------------------------------------------- ***Appendix*** - The loop to `getApplicationReport` from the RM is duplicated in 4 places: in the stable `Client`, alpha `Client`, and twice in `YarnClientSchedulerBackend`. We should not have different loops for client and cluster deploy modes. - There are many fragmented small helper methods that are only used once and should just be inlined. For instance, `ClientBase#getLocalPath` returns `null` on certain conditions, and its only caller `ClientBase#addFileToClasspath` checks whether the value returned is `null`. We could just have the caller check on that same condition to avoid passing `null`s around. - In `YarnSparkHadoopUtil#addToEnvironment`, we take in an argument `classpathSeparator` that always has the same value upstream (i.e. `File.pathSeparator`). This argument is now removed from the signature and all callers of this method upstream. - `ClientBase#copyRemoteFile` is now renamed to `copyFileToRemote`. It was unclear whether we are copying a remote file to our local file system, or copying a locally visible file to a remote file system. Also, even the content of the method has inaccurately named variables. We use `val remoteFs` to signify the file system of the locally visible file and `val fs` to signify the remote, destination file system. These are now renamed `srcFs` and `destFs` respectively. - We currently log the AM container's environment and resource mappings directly as Scala collections. This is incredibly hard to read and probably too verbose for the average Spark user. In other modes (e.g. standalone), we also don't log the launch commands by default, so the logging level of these information is now set to `DEBUG`. - None of these classes (`Client`, `ClientBase`, `YarnSparkHadoopUtil` etc.) is intended to be used by a Spark application (the user should go through Spark submit instead). At the very least they should be `private[spark]`. Author: Andrew Or Closes #2350 from andrewor14/yarn-cleanup and squashes the following commits: 39e8c7b [Andrew Or] Address review comments 6619f9b [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup 2ca6d64 [Andrew Or] Improve logging in application monitor a3b9693 [Andrew Or] Minor changes 7dd6298 [Andrew Or] Simplify ClientBase#monitorApplication 547487c [Andrew Or] Provide default values for null application report entries a0ad1e9 [Andrew Or] Fix class not found error 1590141 [Andrew Or] Address review comments 45ccdea [Andrew Or] Remove usages of getAMMemory d8e33b6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup ed0b42d [Andrew Or] Fix alpha compilation error c0587b4 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup 6d74888 [Andrew Or] Minor comment changes 6573c1d [Andrew Or] Clean up, simplify and document code for setting classpaths e4779b6 [Andrew Or] Clean up log messages + variable naming in ClientBase 8766d37 [Andrew Or] Heavily add documentation to Client* classes + various clean-ups 6c94d79 [Andrew Or] Various cleanups in ClientBase and ClientArguments ef7069a [Andrew Or] Clean up YarnClientSchedulerBackend more 6de9072 [Andrew Or] Guard against potential NPE in debug logging mode fabe4c4 [Andrew Or] Reuse more code in YarnClientSchedulerBackend 3f941dc [Andrew Or] First cut at simplifying the Client (stable and alpha) --- .../org/apache/spark/deploy/yarn/Client.scala | 145 ++-- .../spark/deploy/yarn/ClientArguments.scala | 67 +- .../apache/spark/deploy/yarn/ClientBase.scala | 682 +++++++++++------- .../yarn/ClientDistributedCacheManager.scala | 97 +-- .../deploy/yarn/ExecutorRunnableUtil.scala | 16 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 63 +- .../cluster/YarnClientSchedulerBackend.scala | 145 ++-- .../spark/deploy/yarn/ClientBaseSuite.scala | 18 +- .../org/apache/spark/deploy/yarn/Client.scala | 167 ++--- 9 files changed, 738 insertions(+), 662 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index aff9ab71f0937..5a20532315e59 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -23,13 +23,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.YarnClientImpl import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, Records} +import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil @@ -37,7 +35,10 @@ import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's alpha API. */ -class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf) +private[spark] class Client( + val args: ClientArguments, + val hadoopConf: Configuration, + val sparkConf: SparkConf) extends YarnClientImpl with ClientBase with Logging { def this(clientArgs: ClientArguments, spConf: SparkConf) = @@ -45,112 +46,86 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - val args = clientArgs - val conf = hadoopConf - val sparkConf = spConf - var rpc: YarnRPC = YarnRPC.create(conf) - val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(hadoopConf) + /* ------------------------------------------------------------------------------------- * + | The following methods have much in common in the stable and alpha versions of Client, | + | but cannot be implemented in the parent trait due to subtle API differences across | + | hadoop versions. | + * ------------------------------------------------------------------------------------- */ - // for client user who want to monitor app status by itself. - def runApp() = { - validateArgs() - + /** Submit an application running our ApplicationMaster to the ResourceManager. */ + override def submitApplication(): ApplicationId = { init(yarnConf) start() - logClusterResourceDetails() - val newApp = super.getNewApplication() - val appId = newApp.getApplicationId() + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(getYarnClusterMetrics.getNumNodeManagers)) - verifyClusterResources(newApp) - val appContext = createApplicationSubmissionContext(appId) - val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val env = setupLaunchEnv(localResources, appStagingDir) - val amContainer = createContainerLaunchContext(newApp, localResources, env) + // Get a new application from our RM + val newAppResponse = getNewApplication() + val appId = newAppResponse.getApplicationId() - val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - // Memory for the ApplicationMaster. - capability.setMemory(args.amMemory + memoryOverhead) - amContainer.setResource(capability) + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) - appContext.setQueue(args.amQueue) - appContext.setAMContainerSpec(amContainer) - appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(appId, containerContext) - submitApp(appContext) + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + submitApplication(appContext) appId } - def run() { - val appId = runApp() - monitorApplication(appId) - } - - def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics - logInfo("Got cluster metric info from ASM, numNodeManagers = " + - clusterMetrics.getNumNodeManagers) + /** + * Set up a context for launching our ApplicationMaster container. + * In the Yarn alpha API, the memory requirements of this container must be set in + * the ContainerLaunchContext instead of the ApplicationSubmissionContext. + */ + override def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) + : ContainerLaunchContext = { + val containerContext = super.createContainerLaunchContext(newAppResponse) + val capability = Records.newRecord(classOf[Resource]) + capability.setMemory(args.amMemory + amMemoryOverhead) + containerContext.setResource(capability) + containerContext } - - def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { - logInfo("Setting up application submission context for ASM") + /** Set up the context for submitting our ApplicationMaster. */ + def createApplicationSubmissionContext( + appId: ApplicationId, + containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) appContext.setApplicationId(appId) appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(containerContext) + appContext.setUser(UserGroupInformation.getCurrentUser.getShortUserName) appContext } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { - // Setup security tokens. + /** + * Set up security tokens for launching our ApplicationMaster container. + * ContainerLaunchContext#setContainerTokens is renamed `setTokens` in the stable API. + */ + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { val dob = new DataOutputBuffer() credentials.writeTokenStorageToStream(dob) amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData())) } - def submitApp(appContext: ApplicationSubmissionContext) = { - // Submit the application to the applications manager. - logInfo("Submitting application to ASM") - super.submitApplication(appContext) - } - - def monitorApplication(appId: ApplicationId): Boolean = { - val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) - - while (true) { - Thread.sleep(interval) - val report = super.getApplicationReport(appId) - - logInfo("Application report from ASM: \n" + - "\t application identifier: " + appId.toString() + "\n" + - "\t appId: " + appId.getId() + "\n" + - "\t clientToken: " + report.getClientToken() + "\n" + - "\t appDiagnostics: " + report.getDiagnostics() + "\n" + - "\t appMasterHost: " + report.getHost() + "\n" + - "\t appQueue: " + report.getQueue() + "\n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + - "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + - "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + - "\t appUser: " + report.getUser() - ) - - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - return true - } - } - true - } + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. + * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. + */ + override def getClientToken(report: ApplicationReport): String = + Option(report.getClientToken).getOrElse("") } object Client { - def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a " + @@ -158,19 +133,17 @@ object Client { } // Set an env variable indicating we are running in YARN mode. - // Note that anything with SPARK prefix gets propagated to all (remote) processes + // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") - val sparkConf = new SparkConf try { val args = new ClientArguments(argStrings, sparkConf) new Client(args, sparkConf).run() } catch { - case e: Exception => { + case e: Exception => Console.err.println(e.getMessage) System.exit(1) - } } System.exit(0) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 40d8d6d6e6961..201b742736c6e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.yarn -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf -import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! -class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { +private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { var addJars: String = null var files: String = null var archives: String = null @@ -35,28 +34,56 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { var executorMemory = 1024 // MB var executorCores = 1 var numExecutors = 2 - var amQueue = sparkConf.get("QUEUE", "default") + var amQueue = sparkConf.get("spark.yarn.queue", "default") var amMemory: Int = 512 // MB var appName: String = "Spark" var priority = 0 - parseArgs(args.toList) + // Additional memory to allocate to containers + // For now, use driver's memory overhead as our AM container's memory overhead + val amMemoryOverhead = sparkConf.getInt( + "spark.yarn.driver.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + val executorMemoryOverhead = sparkConf.getInt( + "spark.yarn.executor.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - // env variable SPARK_YARN_DIST_ARCHIVES/SPARK_YARN_DIST_FILES set in yarn-client then - // it should default to hdfs:// - files = Option(files).getOrElse(sys.env.get("SPARK_YARN_DIST_FILES").orNull) - archives = Option(archives).getOrElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES").orNull) + parseArgs(args.toList) + loadEnvironmentArgs() + validateArgs() + + /** Load any default arguments provided through environment variables and Spark properties. */ + private def loadEnvironmentArgs(): Unit = { + // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, + // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051). + files = Option(files) + .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) + .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p))) + .orNull + archives = Option(archives) + .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) + .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) + .orNull + } - // spark.yarn.dist.archives/spark.yarn.dist.files defaults to use file:// if not specified, - // for both yarn-client and yarn-cluster - files = Option(files).getOrElse(sparkConf.getOption("spark.yarn.dist.files"). - map(p => Utils.resolveURIs(p)).orNull) - archives = Option(archives).getOrElse(sparkConf.getOption("spark.yarn.dist.archives"). - map(p => Utils.resolveURIs(p)).orNull) + /** + * Fail fast if any arguments provided are invalid. + * This is intended to be called only after the provided arguments have been parsed. + */ + private def validateArgs(): Unit = { + // TODO: memory checks are outdated (SPARK-3476) + Map[Boolean, String]( + (numExecutors <= 0) -> "You must specify at least 1 executor!", + (amMemory <= amMemoryOverhead) -> s"AM memory must be > $amMemoryOverhead MB", + (executorMemory <= executorMemoryOverhead) -> + s"Executor memory must be > $executorMemoryOverhead MB" + ).foreach { case (errorCondition, errorMessage) => + if (errorCondition) { + throw new IllegalArgumentException(errorMessage + "\n" + getUsageMessage()) + } + } + } private def parseArgs(inputArgs: List[String]): Unit = { - val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() - + val userArgsBuffer = new ArrayBuffer[String]() var args = inputArgs while (!args.isEmpty) { @@ -138,16 +165,14 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userArgs = userArgsBuffer.readOnly } - - def getUsageMessage(unknownParam: Any = null): String = { + private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" - message + "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + - " --arg ARGS Argument to be passed to your application's main class.\n" + + " --arg ARG Argument to be passed to your application's main class.\n" + " Multiple invocations are possible, each will be passed in order.\n" + " --num-executors NUM Number of executors to start (Default: 2)\n" + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ae4d496220a5..4870b0cb3ddaf 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.yarn -import java.io.File import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import scala.collection.JavaConversions._ @@ -37,154 +36,107 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records + import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** - * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The - * Client submits an application to the YARN ResourceManager. + * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. + * The Client submits an application to the YARN ResourceManager. */ -trait ClientBase extends Logging { - val args: ClientArguments - val conf: Configuration - val sparkConf: SparkConf - val yarnConf: YarnConfiguration - val credentials = UserGroupInformation.getCurrentUser().getCredentials() - private val SPARK_STAGING: String = ".sparkStaging" +private[spark] trait ClientBase extends Logging { + import ClientBase._ + + protected val args: ClientArguments + protected val hadoopConf: Configuration + protected val sparkConf: SparkConf + protected val yarnConf: YarnConfiguration + protected val credentials = UserGroupInformation.getCurrentUser.getCredentials + protected val amMemoryOverhead = args.amMemoryOverhead // MB + protected val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() - // Staging directory is private! -> rwx-------- - val STAGING_DIR_PERMISSION: FsPermission = - FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) - // App files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = - FsPermission.createImmutable(Integer.parseInt("644", 8).toShort) - - // Additional memory overhead - in mb. - protected def memoryOverhead: Int = sparkConf.getInt("spark.yarn.driver.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - - // TODO(harvey): This could just go in ClientArguments. - def validateArgs() = { - Map( - (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!", - (args.amMemory <= memoryOverhead) -> ("Error: AM memory size must be" + - "greater than: " + memoryOverhead), - (args.executorMemory <= memoryOverhead) -> ("Error: Executor memory size" + - "must be greater than: " + memoryOverhead.toString) - ).foreach { case(cond, errStr) => - if (cond) { - logError(errStr) - throw new IllegalArgumentException(args.getUsageMessage()) - } - } - } - - def getAppStagingDir(appId: ApplicationId): String = { - SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR - } - - def verifyClusterResources(app: GetNewApplicationResponse) = { - val maxMem = app.getMaximumResourceCapability().getMemory() - logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) - - // If we have requested more then the clusters max for a single resource then exit. - if (args.executorMemory > maxMem) { - val errorMessage = - "Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster." - .format(args.executorMemory, maxMem) - - logError(errorMessage) - throw new IllegalArgumentException(errorMessage) - } - val amMem = args.amMemory + memoryOverhead + /** + * Fail fast if we have requested more resources per container than is available in the cluster. + */ + protected def verifyClusterResources(newAppResponse: GetNewApplicationResponse): Unit = { + val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() + logInfo("Verifying our application has not requested more than the maximum " + + s"memory capability of the cluster ($maxMem MB per container)") + val executorMem = args.executorMemory + executorMemoryOverhead + if (executorMem > maxMem) { + throw new IllegalArgumentException(s"Required executor memory ($executorMem MB) " + + s"is above the max threshold ($maxMem MB) of this cluster!") + } + val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - - val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." - .format(amMem, maxMem) - logError(errorMessage) - throw new IllegalArgumentException(errorMessage) + throw new IllegalArgumentException(s"Required AM memory ($amMem MB) " + + s"is above the max threshold ($maxMem MB) of this cluster!") } - // We could add checks to make sure the entire cluster has enough resources but that involves // getting all the node reports and computing ourselves. } - /** See if two file systems are the same or not. */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { - return false - } - var srcHost = srcUri.getHost() - var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { - try { - srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() - dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() - } catch { - case e: UnknownHostException => - return false - } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true - } - } - - /** Copy the file into HDFS if needed. */ - private[yarn] def copyRemoteFile( - dstDir: Path, - originalPath: Path, + /** + * Copy the given file to a remote file system (e.g. HDFS) if needed. + * The file is only copied if the source and destination file systems are different. This is used + * for preparing resources for launching the ApplicationMaster container. Exposed for testing. + */ + def copyFileToRemote( + destDir: Path, + srcPath: Path, replication: Short, setPerms: Boolean = false): Path = { - val fs = FileSystem.get(conf) - val remoteFs = originalPath.getFileSystem(conf) - var newPath = originalPath - if (!compareFs(remoteFs, fs)) { - newPath = new Path(dstDir, originalPath.getName()) - logInfo("Uploading " + originalPath + " to " + newPath) - FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) - fs.setReplication(newPath, replication) - if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + val destFs = destDir.getFileSystem(hadoopConf) + val srcFs = srcPath.getFileSystem(hadoopConf) + var destPath = srcPath + if (!compareFs(srcFs, destFs)) { + destPath = new Path(destDir, srcPath.getName()) + logInfo(s"Uploading resource $srcPath -> $destPath") + FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf) + destFs.setReplication(destPath, replication) + if (setPerms) { + destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) + } + } else { + logInfo(s"Source and destination file systems are the same. Not copying $srcPath") } // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific // version shows the specific version in the distributed cache configuration - val qualPath = fs.makeQualified(newPath) - val fc = FileContext.getFileContext(qualPath.toUri(), conf) - val destPath = fc.resolvePath(qualPath) - destPath + val qualifiedDestPath = destFs.makeQualified(destPath) + val fc = FileContext.getFileContext(qualifiedDestPath.toUri(), hadoopConf) + fc.resolvePath(qualifiedDestPath) } - private def qualifyForLocal(localURI: URI): Path = { - var qualifiedURI = localURI - // If not specified, assume these are in the local filesystem to keep behavior like Hadoop - if (qualifiedURI.getScheme() == null) { - qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString) - } + /** + * Given a local URI, resolve it and return a qualified local path that corresponds to the URI. + * This is used for preparing local resources to be included in the container launch context. + */ + private def getQualifiedLocalPath(localURI: URI): Path = { + val qualifiedURI = + if (localURI.getScheme == null) { + // If not specified, assume this is in the local filesystem to keep the behavior + // consistent with that of Hadoop + new URI(FileSystem.getLocal(hadoopConf).makeQualified(new Path(localURI)).toString) + } else { + localURI + } new Path(qualifiedURI) } + /** + * Upload any resources to the distributed cache if needed. If a resource is intended to be + * consumed locally, set up the appropriate config for downstream code to handle it properly. + * This is used for setting up a container launch context for our ApplicationMaster. + * Exposed for testing. + */ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { - logInfo("Preparing Local resources") - // Upload Spark and the application JAR to the remote file system if necessary. Add them as - // local resources to the application master. - val fs = FileSystem.get(conf) + logInfo("Preparing resources for our AM container") + // Upload Spark and the application JAR to the remote file system if necessary, + // and add them as local resources to the application master. + val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst - ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val nns = getNameNodesToAccess(sparkConf) + dst + obtainTokensForNamenodes(nns, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() @@ -200,73 +152,84 @@ trait ClientBase extends Logging { "for alternatives.") } + /** + * Copy the given main resource to the distributed cache if the scheme is not "local". + * Otherwise, set the corresponding key in our SparkConf to handle it downstream. + * Each resource is represented by a 4-tuple of: + * (1) destination resource name, + * (2) local path to the resource, + * (3) Spark property key to set if the scheme is not local, and + * (4) whether to set permissions for this resource + */ List( - (ClientBase.SPARK_JAR, ClientBase.sparkJar(sparkConf), ClientBase.CONF_SPARK_JAR), - (ClientBase.APP_JAR, args.userJar, ClientBase.CONF_SPARK_USER_JAR), - ("log4j.properties", oldLog4jConf.getOrElse(null), null) - ).foreach { case(destName, _localPath, confKey) => + (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false), + (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true), + ("log4j.properties", oldLog4jConf.orNull, null, false) + ).foreach { case (destName, _localPath, confKey, setPermissions) => val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (! localPath.isEmpty()) { + if (!localPath.isEmpty()) { val localURI = new URI(localPath) - if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { - val setPermissions = destName.equals(ClientBase.APP_JAR) - val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) - val destFs = FileSystem.get(destPath.toUri(), conf) - distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE, - destName, statCache) + if (localURI.getScheme != LOCAL_SCHEME) { + val src = getQualifiedLocalPath(localURI) + val destPath = copyFileToRemote(dst, src, replication, setPermissions) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource(destFs, hadoopConf, destPath, + localResources, LocalResourceType.FILE, destName, statCache) } else if (confKey != null) { + // If the resource is intended for local use only, handle this downstream + // by setting the appropriate property sparkConf.set(confKey, localPath) } } } + /** + * Do the same for any additional resources passed in through ClientArguments. + * Each resource category is represented by a 3-tuple of: + * (1) comma separated list of resources in this category, + * (2) resource type, and + * (3) whether to add these resources to the classpath + */ val cachedSecondaryJarLinks = ListBuffer.empty[String] - val fileLists = List( (args.addJars, LocalResourceType.FILE, true), + List( + (args.addJars, LocalResourceType.FILE, true), (args.files, LocalResourceType.FILE, false), - (args.archives, LocalResourceType.ARCHIVE, false) ) - fileLists.foreach { case (flist, resType, addToClasspath) => + (args.archives, LocalResourceType.ARCHIVE, false) + ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { - flist.split(',').foreach { case file: String => + flist.split(',').foreach { file => val localURI = new URI(file.trim()) - if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { + if (localURI.getScheme != LOCAL_SCHEME) { val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, resType, - linkname, statCache) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache) if (addToClasspath) { cachedSecondaryJarLinks += linkname } } else if (addToClasspath) { + // Resource is intended for local use only and should be added to the class path cachedSecondaryJarLinks += file.trim() } } } } - logInfo("Prepared Local resources " + localResources) - sparkConf.set(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) + if (cachedSecondaryJarLinks.nonEmpty) { + sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) + } - UserGroupInformation.getCurrentUser().addCredentials(credentials) localResources } - /** Get all application master environment variables set on this SparkConf */ - def getAppMasterEnv: Seq[(String, String)] = { - val prefix = "spark.yarn.appMasterEnv." - sparkConf.getAll.filter{case (k, v) => k.startsWith(prefix)} - .map{case (k, v) => (k.substring(prefix.length), v)} - } - - - def setupLaunchEnv( - localResources: HashMap[String, LocalResource], - stagingDir: String): HashMap[String, String] = { - logInfo("Setting up the launch environment") - + /** + * Set up the environment for launching our ApplicationMaster container. + */ + private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - ClientBase.populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -275,42 +238,20 @@ trait ClientBase extends Logging { distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) - getAppMasterEnv.foreach { case (key, value) => - YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) - } + // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* + val amEnvPrefix = "spark.yarn.appMasterEnv." + sparkConf.getAll + .filter { case (k, v) => k.startsWith(amEnvPrefix) } + .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } + .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } // Keep this for backwards compatibility but users should move to the config sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => // Allow users to specify some environment variables. - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs, File.pathSeparator) - + YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. env("SPARK_YARN_USER_ENV") = userEnvs } - env - } - - def userArgsToString(clientArgs: ClientArguments): String = { - val prefix = " --arg " - val args = clientArgs.userArgs - val retval = new StringBuilder() - for (arg <- args) { - retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg)) - } - retval.toString - } - - def setupSecurityToken(amContainer: ContainerLaunchContext) - - def createContainerLaunchContext( - newApp: GetNewApplicationResponse, - localResources: HashMap[String, LocalResource], - env: HashMap[String, String]): ContainerLaunchContext = { - logInfo("Setting up container launch context") - val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - - val isLaunchingDriver = args.userClass != null // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's @@ -320,6 +261,7 @@ trait ClientBase extends Logging { // Note that to warn the user about the deprecation in cluster mode, some code from // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition // described above). + val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = @@ -342,14 +284,30 @@ trait ClientBase extends Logging { env("SPARK_JAVA_OPTS") = value } } - amContainer.setEnvironment(env) - val amMemory = args.amMemory + env + } + + /** + * Set up a ContainerLaunchContext to launch our ApplicationMaster container. + * This sets up the launch environment, java options, and the command for launching the AM. + */ + protected def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) + : ContainerLaunchContext = { + logInfo("Setting up container launch context for our AM") + + val appId = newAppResponse.getApplicationId + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val launchEnv = setupLaunchEnv(appStagingDir) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(launchEnv) val javaOpts = ListBuffer[String]() // Add Xmx for AM memory - javaOpts += "-Xmx" + amMemory + "m" + javaOpts += "-Xmx" + args.amMemory + "m" val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) javaOpts += "-Djava.io.tmpdir=" + tmpDir @@ -361,8 +319,7 @@ trait ClientBase extends Logging { // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset // of cores on a node. - val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && - java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC")) + val useConcurrentAndIncrementalGC = launchEnv.get("SPARK_USE_CONC_INCR_GC").exists(_.toBoolean) if (useConcurrentAndIncrementalGC) { // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines @@ -380,6 +337,8 @@ trait ClientBase extends Logging { javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + // Include driver-specific java options if we are launching a driver + val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) @@ -397,19 +356,27 @@ trait ClientBase extends Logging { } else { Nil } + val userJar = + if (args.userJar != null) { + Seq("--jar", args.userJar) + } else { + Nil + } val amClass = if (isLaunchingDriver) { - classOf[ApplicationMaster].getName() + Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - classOf[ApplicationMaster].getName().replace("ApplicationMaster", "ExecutorLauncher") + Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } + val userArgs = args.userArgs.flatMap { arg => + Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) + } val amArgs = - Seq(amClass) ++ userClass ++ - (if (args.userJar != null) Seq("--jar", args.userJar) else Nil) ++ - Seq("--executor-memory", args.executorMemory.toString, + Seq(amClass) ++ userClass ++ userJar ++ userArgs ++ + Seq( + "--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, - userArgsToString(args)) + "--num-executors ", args.numExecutors.toString) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ @@ -418,41 +385,153 @@ trait ClientBase extends Logging { "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - logInfo("Yarn AM launch context:") - logInfo(s" user class: ${args.userClass}") - logInfo(s" env: $env") - logInfo(s" command: ${commands.mkString(" ")}") - // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList amContainer.setCommands(printableCommands) - setupSecurityToken(amContainer) + logDebug("===============================================================================") + logDebug("Yarn AM launch context:") + logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") + logDebug(" env:") + launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } + logDebug(" resources:") + localResources.foreach { case (k, v) => logDebug(s" $k -> $v")} + logDebug(" command:") + logDebug(s" ${printableCommands.mkString(" ")}") + logDebug("===============================================================================") // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + setupSecurityToken(amContainer) + UserGroupInformation.getCurrentUser().addCredentials(credentials) amContainer } + + /** + * Report the state of an application until it has exited, either successfully or + * due to some failure, then return the application state. + * + * @param appId ID of the application to monitor. + * @param returnOnRunning Whether to also return the application state when it is RUNNING. + * @param logApplicationReport Whether to log details of the application report every iteration. + * @return state of the application, one of FINISHED, FAILED, KILLED, and RUNNING. + */ + def monitorApplication( + appId: ApplicationId, + returnOnRunning: Boolean = false, + logApplicationReport: Boolean = true): YarnApplicationState = { + val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) + var lastState: YarnApplicationState = null + while (true) { + Thread.sleep(interval) + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + + if (logApplicationReport) { + logInfo(s"Application report for $appId (state: $state)") + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + val formattedDetails = details + .map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" } + .mkString("") + + // If DEBUG is enabled, log report details every iteration + // Otherwise, log them every time the application changes state + if (log.isDebugEnabled) { + logDebug(formattedDetails) + } else if (lastState != state) { + logInfo(formattedDetails) + } + } + + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return state + } + + if (returnOnRunning && state == YarnApplicationState.RUNNING) { + return state + } + + lastState = state + } + + // Never reached, but keeps compiler happy + throw new SparkException("While loop is depleted! This should never happen...") + } + + /** + * Submit an application to the ResourceManager and monitor its state. + * This continues until the application has exited for any reason. + */ + def run(): Unit = monitorApplication(submitApplication()) + + /* --------------------------------------------------------------------------------------- * + | Methods that cannot be implemented here due to API differences across hadoop versions | + * --------------------------------------------------------------------------------------- */ + + /** Submit an application running our ApplicationMaster to the ResourceManager. */ + def submitApplication(): ApplicationId + + /** Set up security tokens for launching our ApplicationMaster container. */ + protected def setupSecurityToken(containerContext: ContainerLaunchContext): Unit + + /** Get the application report from the ResourceManager for an application we have submitted. */ + protected def getApplicationReport(appId: ApplicationId): ApplicationReport + + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. + */ + protected def getClientToken(report: ApplicationReport): String } -object ClientBase extends Logging { +private[spark] object ClientBase extends Logging { + + // Alias for the Spark assembly jar and the user jar val SPARK_JAR: String = "__spark__.jar" val APP_JAR: String = "__app__.jar" + + // URI scheme that identifies local resources val LOCAL_SCHEME = "local" + + // Staging directory for any temporary jars or files + val SPARK_STAGING: String = ".sparkStaging" + + // Location of any user-defined Spark jars val CONF_SPARK_JAR = "spark.yarn.jar" - /** - * This is an internal config used to propagate the location of the user's jar file to the - * driver/executors. - */ + val ENV_SPARK_JAR = "SPARK_JAR" + + // Internal config to propagate the location of the user's jar to the driver/executors val CONF_SPARK_USER_JAR = "spark.yarn.user.jar" - /** - * This is an internal config used to propagate the list of extra jars to add to the classpath - * of executors. - */ + + // Internal config to propagate the locations of any extra jars to add to the classpath + // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" - val ENV_SPARK_JAR = "SPARK_JAR" + + // Staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = + FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) + + // App files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = + FsPermission.createImmutable(Integer.parseInt("644", 8).toShort) /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -461,7 +540,7 @@ object ClientBase extends Logging { * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the * user environment if that is not found (for backwards compatibility). */ - def sparkJar(conf: SparkConf) = { + private def sparkJar(conf: SparkConf): String = { if (conf.contains(CONF_SPARK_JAR)) { conf.get(CONF_SPARK_JAR) } else if (System.getenv(ENV_SPARK_JAR) != null) { @@ -474,16 +553,22 @@ object ClientBase extends Logging { } } - def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = { + /** + * Return the path to the given application's staging directory. + */ + private def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + } + + /** + * Populate the classpath entry in the given environment map with any application + * classpath specified through the Hadoop and Yarn configurations. + */ + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]): Unit = { val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf) for (c <- classPathElementsToAdd.flatten) { - YarnSparkHadoopUtil.addToEnvironment( - env, - Environment.CLASSPATH.name, - c.trim, - File.pathSeparator) + YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, c.trim) } - classPathElementsToAdd } private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] = @@ -519,7 +604,7 @@ object ClientBase extends Logging { /** * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. + * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. * So we need to use reflection to retrieve it. */ def getDefaultMRApplicationClasspath: Option[Seq[String]] = { @@ -545,8 +630,16 @@ object ClientBase extends Logging { triedDefault.toOption } - def populateClasspath(args: ClientArguments, conf: Configuration, sparkConf: SparkConf, - env: HashMap[String, String], extraClassPath: Option[String] = None) { + /** + * Populate the classpath entry in the given environment map. + * This includes the user jar, Spark jar, and any extra application jars. + */ + def populateClasspath( + args: ClientArguments, + conf: Configuration, + sparkConf: SparkConf, + env: HashMap[String, String], + extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) addClasspathEntry(Environment.PWD.$(), env) @@ -554,36 +647,40 @@ object ClientBase extends Logging { if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { addUserClasspath(args, sparkConf, env) addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - ClientBase.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) } else { addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - ClientBase.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) addUserClasspath(args, sparkConf, env) } // Append all jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env); + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env) } /** * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly * to the classpath. */ - private def addUserClasspath(args: ClientArguments, conf: SparkConf, - env: HashMap[String, String]) = { - if (args != null) { - addFileToClasspath(args.userJar, APP_JAR, env) - if (args.addJars != null) { - args.addJars.split(",").foreach { case file: String => - addFileToClasspath(file, null, env) - } + private def addUserClasspath( + args: ClientArguments, + conf: SparkConf, + env: HashMap[String, String]): Unit = { + + // If `args` is not null, we are launching an AM container. + // Otherwise, we are launching executor containers. + val (mainJar, secondaryJars) = + if (args != null) { + (args.userJar, args.addJars) + } else { + (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null)) } - } else { - val userJar = conf.get(CONF_SPARK_USER_JAR, null) - addFileToClasspath(userJar, APP_JAR, env) - val cachedSecondaryJarLinks = conf.get(CONF_SPARK_YARN_SECONDARY_JARS, "").split(",") - cachedSecondaryJarLinks.foreach(jar => addFileToClasspath(jar, null, env)) + addFileToClasspath(mainJar, APP_JAR, env) + if (secondaryJars != null) { + secondaryJars.split(",").filter(_.nonEmpty).foreach { jar => + addFileToClasspath(jar, null, env) + } } } @@ -599,46 +696,44 @@ object ClientBase extends Logging { * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ - private def addFileToClasspath(path: String, fileName: String, - env: HashMap[String, String]) : Unit = { + private def addFileToClasspath( + path: String, + fileName: String, + env: HashMap[String, String]): Unit = { if (path != null) { scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { - val localPath = getLocalPath(path) - if (localPath != null) { - addClasspathEntry(localPath, env) + val uri = new URI(path) + if (uri.getScheme == LOCAL_SCHEME) { + addClasspathEntry(uri.getPath, env) return } } } if (fileName != null) { - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env); + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env) } } /** - * Returns the local path if the URI is a "local:" URI, or null otherwise. + * Add the given path to the classpath entry of the given environment map. + * If the classpath is already set, this appends the new path to the existing classpath. */ - private def getLocalPath(resource: String): String = { - val uri = new URI(resource) - if (LOCAL_SCHEME.equals(uri.getScheme())) { - return uri.getPath() - } - null - } - - private def addClasspathEntry(path: String, env: HashMap[String, String]) = - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, - File.pathSeparator) + private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = + YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) /** * Get the list of namenodes the user may access. */ - private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) - .map(new Path(_)).toSet + def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "") + .split(",") + .map(_.trim()) + .filter(!_.isEmpty) + .map(new Path(_)) + .toSet } - private[yarn] def getTokenRenewer(conf: Configuration): String = { + def getTokenRenewer(conf: Configuration): String = { val delegTokenRenewer = Master.getMasterPrincipal(conf) logDebug("delegation token renewer is: " + delegTokenRenewer) if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { @@ -652,17 +747,54 @@ object ClientBase extends Logging { /** * Obtains tokens for the namenodes passed in and adds them to the credentials. */ - private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, - creds: Credentials) { + def obtainTokensForNamenodes( + paths: Set[Path], + conf: Configuration, + creds: Credentials): Unit = { if (UserGroupInformation.isSecurityEnabled()) { val delegTokenRenewer = getTokenRenewer(conf) + paths.foreach { dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } - paths.foreach { - dst => - val dstFs = dst.getFileSystem(conf) - logDebug("getting token for namenode: " + dst) - dstFs.addDelegationTokens(delegTokenRenewer, creds) + /** + * Return whether the two file systems are the same. + */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() + } catch { + case e: UnknownHostException => + return false } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + if (srcUri.getPort() != dstUri.getPort()) { + false + } else { + true } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 9b7f1fca96c6d..c592ecfdfce06 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -19,29 +19,24 @@ package org.apache.spark.deploy.yarn import java.net.URI +import scala.collection.mutable.{HashMap, LinkedHashMap, Map} + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.FsAction -import org.apache.hadoop.yarn.api.records.LocalResource -import org.apache.hadoop.yarn.api.records.LocalResourceVisibility -import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} -import org.apache.spark.Logging - -import scala.collection.mutable.HashMap -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.Map - +import org.apache.spark.Logging /** Client side methods to setup the Hadoop distributed cache */ -class ClientDistributedCacheManager() extends Logging { - private val distCacheFiles: Map[String, Tuple3[String, String, String]] = - LinkedHashMap[String, Tuple3[String, String, String]]() - private val distCacheArchives: Map[String, Tuple3[String, String, String]] = - LinkedHashMap[String, Tuple3[String, String, String]]() +private[spark] class ClientDistributedCacheManager() extends Logging { + + // Mappings from remote URI to (file status, modification time, visibility) + private val distCacheFiles: Map[String, (String, String, String)] = + LinkedHashMap[String, (String, String, String)]() + private val distCacheArchives: Map[String, (String, String, String)] = + LinkedHashMap[String, (String, String, String)]() /** @@ -68,9 +63,9 @@ class ClientDistributedCacheManager() extends Logging { resourceType: LocalResourceType, link: String, statCache: Map[URI, FileStatus], - appMasterOnly: Boolean = false) = { + appMasterOnly: Boolean = false): Unit = { val destStatus = fs.getFileStatus(destPath) - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + val amJarRsrc = Records.newRecord(classOf[LocalResource]) amJarRsrc.setType(resourceType) val visibility = getVisibility(conf, destPath.toUri(), statCache) amJarRsrc.setVisibility(visibility) @@ -80,7 +75,7 @@ class ClientDistributedCacheManager() extends Logging { if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - if (appMasterOnly == false) { + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { @@ -95,12 +90,10 @@ class ClientDistributedCacheManager() extends Logging { /** * Adds the necessary cache file env variables to the env passed in - * @param env */ - def setDistFilesEnv(env: Map[String, String]) = { + def setDistFilesEnv(env: Map[String, String]): Unit = { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 - if (keys.size > 0) { env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = @@ -114,12 +107,10 @@ class ClientDistributedCacheManager() extends Logging { /** * Adds the necessary cache archive env variables to the env passed in - * @param env */ - def setDistArchivesEnv(env: Map[String, String]) = { + def setDistArchivesEnv(env: Map[String, String]): Unit = { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 - if (keys.size > 0) { env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = @@ -133,25 +124,21 @@ class ClientDistributedCacheManager() extends Logging { /** * Returns the local resource visibility depending on the cache file permissions - * @param conf - * @param uri - * @param statCache * @return LocalResourceVisibility */ - def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): - LocalResourceVisibility = { + def getVisibility( + conf: Configuration, + uri: URI, + statCache: Map[URI, FileStatus]): LocalResourceVisibility = { if (isPublic(conf, uri, statCache)) { - return LocalResourceVisibility.PUBLIC - } - LocalResourceVisibility.PRIVATE + LocalResourceVisibility.PUBLIC + } else { + LocalResourceVisibility.PRIVATE + } } /** - * Returns a boolean to denote whether a cache file is visible to all(public) - * or not - * @param conf - * @param uri - * @param statCache + * Returns a boolean to denote whether a cache file is visible to all (public) * @return true if the path in the uri is visible to all, false otherwise */ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { @@ -167,13 +154,12 @@ class ClientDistributedCacheManager() extends Logging { /** * Returns true if all ancestors of the specified path have the 'execute' * permission set for all users (i.e. that other users can traverse - * the directory heirarchy to the given path) - * @param fs - * @param path - * @param statCache + * the directory hierarchy to the given path) * @return true if all ancestors have the 'execute' permission set for all users */ - def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path, + def ancestorsHaveExecutePermissions( + fs: FileSystem, + path: Path, statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { @@ -187,32 +173,25 @@ class ClientDistributedCacheManager() extends Logging { } /** - * Checks for a given path whether the Other permissions on it + * Checks for a given path whether the Other permissions on it * imply the permission in the passed FsAction - * @param fs - * @param path - * @param action - * @param statCache * @return true if the path in the uri is visible to all, false otherwise */ - def checkPermissionOfOther(fs: FileSystem, path: Path, - action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { + def checkPermissionOfOther( + fs: FileSystem, + path: Path, + action: FsAction, + statCache: Map[URI, FileStatus]): Boolean = { val status = getFileStatus(fs, path.toUri(), statCache) val perms = status.getPermission() val otherAction = perms.getOtherAction() - if (otherAction.implies(action)) { - return true - } - false + otherAction.implies(action) } /** - * Checks to see if the given uri exists in the cache, if it does it + * Checks to see if the given uri exists in the cache, if it does it * returns the existing FileStatus, otherwise it stats the uri, stores * it in the cache, and returns the FileStatus. - * @param fs - * @param uri - * @param statCache * @return FileStatus */ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index f56f72cafe50e..bbbf615510762 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.yarn -import java.io.File import java.net.URI import scala.collection.JavaConversions._ @@ -128,9 +127,9 @@ trait ExecutorRunnableUtil extends Logging { localResources: HashMap[String, LocalResource], timestamp: String, size: String, - vis: String) = { + vis: String): Unit = { val uri = new URI(file) - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + val amJarRsrc = Records.newRecord(classOf[LocalResource]) amJarRsrc.setType(rtype) amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) @@ -175,14 +174,17 @@ trait ExecutorRunnableUtil extends Logging { ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => - YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) + // This assumes each executor environment variable set here is a path + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) } // Keep this for backwards compatibility but users should move to the config - YarnSparkHadoopUtil.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"), - File.pathSeparator) + sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => + YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) + } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } env } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4a33e34c3bfc7..0b712c201904a 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.lang.{Boolean => JBoolean} +import java.io.File import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern @@ -29,14 +30,12 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -100,30 +99,26 @@ object YarnSparkHadoopUtil { private val hostToRack = new ConcurrentHashMap[String, String]() private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() - def addToEnvironment( - env: HashMap[String, String], - variable: String, - value: String, - classPathSeparator: String) = { - var envVariable = "" - if (env.get(variable) == None) { - envVariable = value - } else { - envVariable = env.get(variable).get + classPathSeparator + value - } - env put (StringInterner.weakIntern(variable), StringInterner.weakIntern(envVariable)) + /** + * Add a path variable to the given environment map. + * If the map already contains this key, append the value to the existing value instead. + */ + def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = { + val newValue = if (env.contains(key)) { env(key) + File.pathSeparator + value } else value + env.put(key, newValue) } - def setEnvFromInputString( - env: HashMap[String, String], - envString: String, - classPathSeparator: String) = { - if (envString != null && envString.length() > 0) { - var childEnvs = envString.split(",") - var p = Pattern.compile(getEnvironmentVariableRegex()) + /** + * Set zero or more environment variables specified by the given input string. + * The input string is expected to take the form "KEY1=VAL1,KEY2=VAL2,KEY3=VAL3". + */ + def setEnvFromInputString(env: HashMap[String, String], inputString: String): Unit = { + if (inputString != null && inputString.length() > 0) { + val childEnvs = inputString.split(",") + val p = Pattern.compile(environmentVariableRegex) for (cEnv <- childEnvs) { - var parts = cEnv.split("=") // split on '=' - var m = p.matcher(parts(1)) + val parts = cEnv.split("=") // split on '=' + val m = p.matcher(parts(1)) val sb = new StringBuffer while (m.find()) { val variable = m.group(1) @@ -131,8 +126,7 @@ object YarnSparkHadoopUtil { if (env.get(variable) != None) { replace = env.get(variable).get } else { - // if this key is not configured for the child .. get it - // from the env + // if this key is not configured for the child .. get it from the env replace = System.getenv(variable) if (replace == null) { // the env key is note present anywhere .. simply set it @@ -142,14 +136,15 @@ object YarnSparkHadoopUtil { m.appendReplacement(sb, Matcher.quoteReplacement(replace)) } m.appendTail(sb) - addToEnvironment(env, parts(0), sb.toString(), classPathSeparator) + // This treats the environment variable as path variable delimited by `File.pathSeparator` + // This is kept for backward compatibility and consistency with Hadoop's behavior + addPathToEnvironment(env, parts(0), sb.toString) } } } - private def getEnvironmentVariableRegex() : String = { - val osName = System.getProperty("os.name") - if (osName startsWith "Windows") { + private val environmentVariableRegex: String = { + if (Utils.isWindows) { "%([A-Za-z_][A-Za-z0-9_]*?)%" } else { "\\$([A-Za-z_][A-Za-z0-9_]*)" @@ -181,14 +176,14 @@ object YarnSparkHadoopUtil { } } - private[spark] def lookupRack(conf: Configuration, host: String): String = { + def lookupRack(conf: Configuration, host: String): String = { if (!hostToRack.contains(host)) { populateRackInfo(conf, host) } hostToRack.get(host) } - private[spark] def populateRackInfo(conf: Configuration, hostname: String) { + def populateRackInfo(conf: Configuration, hostname: String) { Utils.checkHost(hostname) if (!hostToRack.containsKey(hostname)) { @@ -212,8 +207,8 @@ object YarnSparkHadoopUtil { } } - private[spark] def getApplicationAclsForYarn(securityMgr: SecurityManager): - Map[ApplicationAccessType, String] = { + def getApplicationAclsForYarn(securityMgr: SecurityManager) + : Map[ApplicationAccessType, String] = { Map[ApplicationAccessType, String] ( ApplicationAccessType.VIEW_APP -> securityMgr.getViewAcls, ApplicationAccessType.MODIFY_APP -> securityMgr.getModifyAcls diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 6aa6475fe4a18..200a30899290b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -34,115 +34,120 @@ private[spark] class YarnClientSchedulerBackend( minRegisteredRatio = 0.8 } - var client: Client = null - var appId: ApplicationId = null - var checkerThread: Thread = null - var stopping: Boolean = false - var totalExpectedExecutors = 0 - - private[spark] def addArg(optionName: String, envVar: String, sysProp: String, - arrayBuf: ArrayBuffer[String]) { - if (System.getenv(envVar) != null) { - arrayBuf += (optionName, System.getenv(envVar)) - } else if (sc.getConf.contains(sysProp)) { - arrayBuf += (optionName, sc.getConf.get(sysProp)) - } - } + private var client: Client = null + private var appId: ApplicationId = null + private var stopping: Boolean = false + private var totalExpectedExecutors = 0 + /** + * Create a Yarn client to submit an application to the ResourceManager. + * This waits until the application is running. + */ override def start() { super.start() - val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIHostPort) } val argsArrayBuf = new ArrayBuffer[String]() - argsArrayBuf += ( - "--args", hostport - ) - - // process any optional arguments, given either as environment variables - // or system properties. use the defaults already defined in ClientArguments - // if things aren't specified. system properties override environment - // variables. - List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), - ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), - ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), - ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), - ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), - ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")) - .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } - - logDebug("ClientArguments called with: " + argsArrayBuf) + argsArrayBuf += ("--arg", hostport) + argsArrayBuf ++= getExtraClientArguments + + logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.runApp() - waitForApp() - checkerThread = yarnApplicationStateCheckerThread() + appId = client.submitApplication() + waitForApplication() + asyncMonitorApplication() } - def waitForApp() { - - // TODO : need a better way to find out whether the executors are ready or not - // maybe by resource usage report? - while(true) { - val report = client.getApplicationReport(appId) - - logInfo("Application report from ASM: \n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + /** + * Return any extra command line arguments to be passed to Client provided in the form of + * environment variables or Spark properties. + */ + private def getExtraClientArguments: Seq[String] = { + val extraArgs = new ArrayBuffer[String] + val optionTuples = // List of (target Client argument, environment variable, Spark property) + List( + ("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), + ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), + ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), + ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), + ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), + ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), + ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), + ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name") ) - - // Ready to go, or already gone. - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.RUNNING) { - return - } else if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application already ended," + - "might be killed or not able to launch application master.") + optionTuples.foreach { case (optionName, envVar, sparkProp) => + if (System.getenv(envVar) != null) { + extraArgs += (optionName, System.getenv(envVar)) + } else if (sc.getConf.contains(sparkProp)) { + extraArgs += (optionName, sc.getConf.get(sparkProp)) } + } + extraArgs + } - Thread.sleep(1000) + /** + * Report the state of the application until it is running. + * If the application has finished, failed or been killed in the process, throw an exception. + * This assumes both `client` and `appId` have already been set. + */ + private def waitForApplication(): Unit = { + assert(client != null && appId != null, "Application has not been submitted yet!") + val state = client.monitorApplication(appId, returnOnRunning = true) // blocking + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + throw new SparkException("Yarn application has already ended! " + + "It might have been killed or unable to launch application master.") + } + if (state == YarnApplicationState.RUNNING) { + logInfo(s"Application $appId has started running.") } } - private def yarnApplicationStateCheckerThread(): Thread = { + /** + * Monitor the application state in a separate thread. + * If the application has exited for any reason, stop the SparkContext. + * This assumes both `client` and `appId` have already been set. + */ + private def asyncMonitorApplication(): Unit = { + assert(client != null && appId != null, "Application has not been submitted yet!") val t = new Thread { override def run() { while (!stopping) { val report = client.getApplicationReport(appId) val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED - || state == YarnApplicationState.FAILED) { - logError(s"Yarn application already ended: $state") + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.KILLED || + state == YarnApplicationState.FAILED) { + logError(s"Yarn application has already exited with state $state!") sc.stop() stopping = true } Thread.sleep(1000L) } - checkerThread = null Thread.currentThread().interrupt() } } - t.setName("Yarn Application State Checker") + t.setName("Yarn application state monitor") t.setDaemon(true) t.start() - t } + /** + * Stop the scheduler. This assumes `start()` has already been called. + */ override def stop() { + assert(client != null, "Attempted to stop this scheduler before starting it!") stopping = true super.stop() - client.stop + client.stop() logInfo("Stopped") } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index c3b7a2c8f02e5..9bd916100dd2c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse -import org.apache.hadoop.yarn.api.records.ContainerLaunchContext +import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -90,7 +90,7 @@ class ClientBaseSuite extends FunSuite with Matchers { val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - ClientBase.populateClasspath(args, conf, sparkConf, env, None) + ClientBase.populateClasspath(args, conf, sparkConf, env) val cp = env("CLASSPATH").split(File.pathSeparator) s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -114,10 +114,10 @@ class ClientBaseSuite extends FunSuite with Matchers { val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) val client = spy(new DummyClient(args, conf, sparkConf, yarnConf)) - doReturn(new Path("/")).when(client).copyRemoteFile(any(classOf[Path]), + doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - var tempDir = Files.createTempDir(); + val tempDir = Files.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) @@ -247,13 +247,13 @@ class ClientBaseSuite extends FunSuite with Matchers { private class DummyClient( val args: ClientArguments, - val conf: Configuration, + val hadoopConf: Configuration, val sparkConf: SparkConf, val yarnConf: YarnConfiguration) extends ClientBase { - - override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = - throw new UnsupportedOperationException() - + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = ??? + override def submitApplication(): ApplicationId = ??? + override def getApplicationReport(appId: ApplicationId): ApplicationReport = ??? + override def getClientToken(report: ApplicationReport): String = ??? } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 82e45e3e7ad54..0b43e6ee20538 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,11 +21,9 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer -import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} @@ -34,128 +32,98 @@ import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API. */ -class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf) +private[spark] class Client( + val args: ClientArguments, + val hadoopConf: Configuration, + val sparkConf: SparkConf) extends ClientBase with Logging { - val yarnClient = YarnClient.createYarnClient - def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - val args = clientArgs - val conf = hadoopConf - val sparkConf = spConf - var rpc: YarnRPC = YarnRPC.create(conf) - val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - - def runApp(): ApplicationId = { - validateArgs() - // Initialize and start the client service. + val yarnClient = YarnClient.createYarnClient + val yarnConf = new YarnConfiguration(hadoopConf) + + def stop(): Unit = yarnClient.stop() + + /* ------------------------------------------------------------------------------------- * + | The following methods have much in common in the stable and alpha versions of Client, | + | but cannot be implemented in the parent trait due to subtle API differences across | + | hadoop versions. | + * ------------------------------------------------------------------------------------- */ + + /** + * Submit an application running our ApplicationMaster to the ResourceManager. + * + * The stable Yarn API provides a convenience method (YarnClient#createApplication) for + * creating applications and setting up the application submission context. This was not + * available in the alpha API. + */ + override def submitApplication(): ApplicationId = { yarnClient.init(yarnConf) yarnClient.start() - // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). - logClusterResourceDetails() - - // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM) - // interface). + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - // Get a new client application. + // Get a new application from our RM val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() val appId = newAppResponse.getApplicationId() + // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) - // Set up resource and environment variables. - val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(localResources, appStagingDir) - val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv) + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) - // Set up an application submission context. - val appContext = newApp.getApplicationSubmissionContext() - appContext.setApplicationName(args.appName) - appContext.setQueue(args.amQueue) - appContext.setAMContainerSpec(amContainer) - appContext.setApplicationType("SPARK") - - // Memory for the ApplicationMaster. - val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - memoryResource.setMemory(args.amMemory + memoryOverhead) - appContext.setResource(memoryResource) - - // Finally, submit and monitor the application. - submitApp(appContext) + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) appId } - def run() { - val appId = runApp() - monitorApplication(appId) - } - - def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics - logInfo("Got cluster metric info from ResourceManager, number of NodeManagers: " + - clusterMetrics.getNumNodeManagers) + /** + * Set up the context for submitting our ApplicationMaster. + * This uses the YarnClientApplication not available in the Yarn alpha API. + */ + def createApplicationSubmissionContext( + newApp: YarnClientApplication, + containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { + val appContext = newApp.getApplicationSubmissionContext + appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(containerContext) + appContext.setApplicationType("SPARK") + val capability = Records.newRecord(classOf[Resource]) + capability.setMemory(args.amMemory + amMemoryOverhead) + appContext.setResource(capability) + appContext } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { - // Setup security tokens. - val dob = new DataOutputBuffer() + /** Set up security tokens for launching our ApplicationMaster container. */ + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) - amContainer.setTokens(ByteBuffer.wrap(dob.getData())) + amContainer.setTokens(ByteBuffer.wrap(dob.getData)) } - def submitApp(appContext: ApplicationSubmissionContext) = { - // Submit the application to the applications manager. - logInfo("Submitting application to ResourceManager") - yarnClient.submitApplication(appContext) - } + /** Get the application report from the ResourceManager for an application we have submitted. */ + override def getApplicationReport(appId: ApplicationId): ApplicationReport = + yarnClient.getApplicationReport(appId) - def getApplicationReport(appId: ApplicationId) = - yarnClient.getApplicationReport(appId) - - def stop = yarnClient.stop - - def monitorApplication(appId: ApplicationId): Boolean = { - val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) - - while (true) { - Thread.sleep(interval) - val report = yarnClient.getApplicationReport(appId) - - logInfo("Application report from ResourceManager: \n" + - "\t application identifier: " + appId.toString() + "\n" + - "\t appId: " + appId.getId() + "\n" + - "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + - "\t appDiagnostics: " + report.getDiagnostics() + "\n" + - "\t appMasterHost: " + report.getHost() + "\n" + - "\t appQueue: " + report.getQueue() + "\n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + - "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + - "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + - "\t appUser: " + report.getUser() - ) - - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - return true - } - } - true - } + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. + */ + override def getClientToken(report: ApplicationReport): String = + Option(report.getClientToAMToken).map(_.toString).getOrElse("") } object Client { - def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a " + @@ -163,22 +131,19 @@ object Client { } // Set an env variable indicating we are running in YARN mode. - // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - - // see Client#setupLaunchEnv(). + // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") - val sparkConf = new SparkConf() + val sparkConf = new SparkConf try { val args = new ClientArguments(argStrings, sparkConf) new Client(args, sparkConf).run() } catch { - case e: Exception => { + case e: Exception => Console.err.println(e.getMessage) System.exit(1) - } } System.exit(0) } - } From 11c10df825419372df61a8d23c51e8c3cc78047f Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 23 Sep 2014 11:40:14 -0500 Subject: [PATCH 064/315] [SPARK-3304] [YARN] ApplicationMaster's Finish status is wrong when uncaught exception is thrown from ReporterThread Author: Kousuke Saruta Closes #2198 from sarutak/SPARK-3304 and squashes the following commits: 2696237 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 5b80363 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 4eb0a3e [Kousuke Saruta] Remoed the description about spark.yarn.scheduler.reporterThread.maxFailure 9741597 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 f7538d4 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 358ef8d [Kousuke Saruta] Merge branch 'SPARK-3304' of github.com:sarutak/spark into SPARK-3304 0d138c6 [Kousuke Saruta] Revert "tmp" f8da10a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 b6e9879 [Kousuke Saruta] tmp 8d256ed [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 13b2652 [Kousuke Saruta] Merge branch 'SPARK-3304' of github.com:sarutak/spark into SPARK-3304 2711e15 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 c081f8e [Kousuke Saruta] Modified ApplicationMaster to handle exception in ReporterThread itself 0bbd3a6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 a6982ad [Kousuke Saruta] Added ability handling uncaught exception thrown from Reporter thread --- .../spark/deploy/yarn/ApplicationMaster.scala | 66 +++++++++++++++---- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index cde5fff637a39..9050808157257 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,7 +17,10 @@ package org.apache.spark.deploy.yarn +import scala.util.control.NonFatal + import java.io.IOException +import java.lang.reflect.InvocationTargetException import java.net.Socket import java.util.concurrent.atomic.AtomicReference @@ -55,6 +58,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, @volatile private var finished = false @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + @volatile private var userClassThread: Thread = _ private var reporterThread: Thread = _ private var allocator: YarnAllocator = _ @@ -221,18 +225,48 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // must be <= expiryInterval / 2. val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + // The number of failures in a row until Reporter thread give up + val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) + val t = new Thread { override def run() { + var failureCount = 0 + while (!finished) { - checkNumExecutorsFailed() - if (!finished) { - logDebug("Sending progress") - allocator.allocateResources() - try { - Thread.sleep(interval) - } catch { - case e: InterruptedException => + try { + checkNumExecutorsFailed() + if (!finished) { + logDebug("Sending progress") + allocator.allocateResources() } + failureCount = 0 + } catch { + case e: Throwable => { + failureCount += 1 + if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + logError("Exception was thrown from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, "Exception was thrown" + + s"${failureCount} time(s) from Reporter thread.") + + /** + * If exception is thrown from ReporterThread, + * interrupt user class to stop. + * Without this interrupting, if exception is + * thrown before allocating enough executors, + * YarnClusterScheduler waits until timeout even though + * we cannot allocate executors. + */ + logInfo("Interrupting user class to stop.") + userClassThread.interrupt + } else { + logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + } + } + } + try { + Thread.sleep(interval) + } catch { + case e: InterruptedException => } } } @@ -355,7 +389,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - val t = new Thread { + userClassThread = new Thread { override def run() { var status = FinalApplicationStatus.FAILED try { @@ -366,15 +400,23 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // Some apps have "System.exit(0)" at the end. The user thread will stop here unless // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. status = FinalApplicationStatus.SUCCEEDED + } catch { + case e: InvocationTargetException => { + e.getCause match { + case _: InterruptedException => { + // Reporter thread can interrupt to stop user class + } + } + } } finally { logDebug("Finishing main") } finalStatus = status } } - t.setName("Driver") - t.start() - t + userClassThread.setName("Driver") + userClassThread.start() + userClassThread } // Actor used to monitor the driver when running in client deploy mode. From 66bc0f2d675d06cdd48638f124a1ff32be2bf456 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 23 Sep 2014 11:45:44 -0700 Subject: [PATCH 065/315] [SPARK-3598][SQL]cast to timestamp should be the same as hive this patch fixes timestamp smaller than 0 and cast int as timestamp select cast(1000 as timestamp) from src limit 1; should return 1970-01-01 00:00:01, but we now take it as 1000 seconds. also, current implementation has bug when the time is before 1970-01-01 00:00:00. rxin marmbrus chenghao-intel Author: Daoyuan Wang Closes #2458 from adrian-wang/timestamp and squashes the following commits: 4274b1d [Daoyuan Wang] set test not related to timezone 1234f66 [Daoyuan Wang] fix timestamp smaller than 0 and cast int as timestamp --- .../spark/sql/catalyst/expressions/Cast.scala | 17 +++++++------ .../ExpressionEvaluationSuite.scala | 16 ++++++++----- ...cast #1-0-69fc614ccea92bbe39f4decc299edcc6 | 1 + ...cast #2-0-732ed232ac592c5e7f7c913a88874fd2 | 1 + ... cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 + ...cast #4-0-732ed232ac592c5e7f7c913a88874fd2 | 1 + ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 + ...cast #6-0-6d2da5cfada03605834e38bc4075bc79 | 1 + ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 + ...cast #8-0-6d2da5cfada03605834e38bc4075bc79 | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 24 +++++++++++++++++++ 11 files changed, 50 insertions(+), 15 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0379275121bf2..f626d09f037bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -86,15 +86,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0))) case LongType => - buildCast[Long](_, l => new Timestamp(l * 1000)) + buildCast[Long](_, l => new Timestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i * 1000)) + buildCast[Int](_, i => new Timestamp(i)) case ShortType => - buildCast[Short](_, s => new Timestamp(s * 1000)) + buildCast[Short](_, s => new Timestamp(s)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b * 1000)) + buildCast[Byte](_, b => new Timestamp(b)) // TimestampWritable.decimalToTimestamp case DecimalType => buildCast[BigDecimal](_, d => decimalToTimestamp(d)) @@ -107,11 +107,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def decimalToTimestamp(d: BigDecimal) = { - val seconds = d.longValue() + val seconds = Math.floor(d.toDouble).toLong val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() - // Convert to millis val millis = seconds * 1000 val t = new Timestamp(millis) @@ -121,11 +120,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. - ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 + Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 } // Converts Timestamp to string according to Hive TimestampWritable convention diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b961346dfc995..8b6721d5d8125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -231,7 +231,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) - checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1) + checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) + checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) + checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) @@ -242,11 +244,11 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5) + Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5) + Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -293,16 +295,18 @@ class ExpressionEvaluationSuite extends FunSuite { test("timestamp casting") { val millis = 15 * 1000 + 2 + val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part + val tss = new Timestamp(seconds) checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) checkEvaluation(Cast(ts, LongType), 15) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) - checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1) - checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1) - checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1) + checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) + checkEvaluation(Cast(Cast(tss, IntegerType), TimestampType), ts) + checkEvaluation(Cast(Cast(tss, LongType), TimestampType), ts) checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), diff --git a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 new file mode 100644 index 0000000000000..8ebf695ba7d20 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 @@ -0,0 +1 @@ +0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 new file mode 100644 index 0000000000000..5625e59da8873 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 @@ -0,0 +1 @@ +1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 new file mode 100644 index 0000000000000..5625e59da8873 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 @@ -0,0 +1 @@ +1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 new file mode 100644 index 0000000000000..27de46fdf22ac --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -0,0 +1 @@ +-0.0010000000000000009 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 new file mode 100644 index 0000000000000..1d94c8a014fb4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 @@ -0,0 +1 @@ +-1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f new file mode 100644 index 0000000000000..3fbedf693b51d --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f @@ -0,0 +1 @@ +-2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 new file mode 100644 index 0000000000000..1d94c8a014fb4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 @@ -0,0 +1 @@ +-1.2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 56bcd95eab4bc..6fc891ba4cca5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -303,6 +303,30 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("case statements WITHOUT key #4", "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") + createQueryTest("timestamp cast #1", + "SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #2", + "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #3", + "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + + createQueryTest("timestamp cast #4", + "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #5", + "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #6", + "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #7", + "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + + createQueryTest("timestamp cast #8", + "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => i } From 116016b481cecbd8ad6e9717d92f977a164a6653 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 23 Sep 2014 11:47:53 -0700 Subject: [PATCH 066/315] [SPARK-3582][SQL] not limit argument type for hive simple udf Since we have moved to `ConventionHelper`, it is quite easy to avoid call `javaClassToDataType` in hive simple udf. This will solve SPARK-3582. Author: Daoyuan Wang Closes #2506 from adrian-wang/spark3582 and squashes the following commits: 450c28e [Daoyuan Wang] not limit argument type for hive simple udf --- .../spark/sql/hive/HiveInspectors.scala | 4 ++-- .../org/apache/spark/sql/hive/hiveUdfs.scala | 22 ++----------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 943bbaa8ce25e..fa889ec104c6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -137,7 +137,7 @@ private[hive] trait HiveInspectors { /** Converts native catalyst types to the types expected by Hive */ def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? + case s: String => s: java.lang.String case i: Int => i: java.lang.Integer case b: Boolean => b: java.lang.Boolean case f: Float => f: java.lang.Float @@ -145,7 +145,7 @@ private[hive] trait HiveInspectors { case l: Long => l: java.lang.Long case l: Short => l: java.lang.Short case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal + case b: BigDecimal => new HiveDecimal(b.underlying()) case b: Array[Byte] => b case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 19ff3b66ad7ed..68944ed4ef21d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -51,19 +51,7 @@ private[hive] abstract class HiveFunctionRegistry val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] - val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) - - val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) - - HiveSimpleUdf( - functionClassName, - children.zip(expectedDataTypes).map { - case (e, NullType) => e - case (e, t) if (e.dataType == t) => e - case (e, t) => Cast(e, t) - } - ) + HiveSimpleUdf(functionClassName, children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) } else if ( @@ -117,15 +105,9 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ @transient lazy val dataType = javaClassToDataType(method.getReturnType) - def catalystToHive(value: Any): Object = value match { - // TODO need more types here? or can we use wrap() - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case d => d.asInstanceOf[Object] - } - // TODO: Finish input output types. override def eval(input: Row): Any = { - val evaluatedChildren = children.map(c => catalystToHive(c.eval(input))) + val evaluatedChildren = children.map(c => wrap(c.eval(input))) unwrap(FunctionRegistry.invoke(method, function, conversionHelper .convertIfNecessary(evaluatedChildren: _*): _*)) From 3b8eefa9b843c7f1e0e8dda6023272bc9f011c5c Mon Sep 17 00:00:00 2001 From: ravipesala Date: Tue, 23 Sep 2014 11:52:13 -0700 Subject: [PATCH 067/315] [SPARK-3536][SQL] SELECT on empty parquet table throws exception It returns null metadata from parquet if querying on empty parquet file while calculating splits.So added null check and returns the empty splits. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2456 from ravipesala/SPARK-3536 and squashes the following commits: 1e81a50 [ravipesala] Fixed the issue when querying on empty parquet file. --- .../spark/sql/parquet/ParquetTableOperations.scala | 7 +++++-- .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 9 +++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a5a5d139a65cb..d39e31a7fa195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -427,11 +427,15 @@ private[parquet] class FilteringParquetRowInputFormat s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + // if parquet file is empty, return empty splits. + if (globalMetaData == null) { + return splits + } val readContext = getReadSupport(configuration).init( new InitContext(configuration, @@ -442,7 +446,6 @@ private[parquet] class FilteringParquetRowInputFormat classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get generateSplits.setAccessible(true) - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08f7358446b29..07adf731405af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,4 +789,13 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } + + test("Querying on empty parquet throws exception (SPARK-3536)") { + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") + val result1 = sql("SELECT * FROM tmpemptytable").collect() + assert(result1.size === 0) + Utils.deleteRecursively(tmpdir) + } } From e73b48ace0a7e0f249221240140235d33eeac36b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 23 Sep 2014 11:58:05 -0700 Subject: [PATCH 068/315] SPARK-2745 [STREAMING] Add Java friendly methods to Duration class tdas is this what you had in mind for this JIRA? I saw this one and thought it would be easy to take care of, and helpful as I use streaming from Java. I could do the same for `Time`? Happy to do so. Author: Sean Owen Closes #2403 from srowen/SPARK-2745 and squashes the following commits: 5a9e706 [Sean Owen] Change "Duration" to "Durations" to avoid changing Duration case class API bda301c [Sean Owen] Just delegate to Scala binary operator syntax to avoid scalastyle warning 7dde949 [Sean Owen] Disable scalastyle for false positives. Add Java static factory methods seconds(), minutes() to Duration. Add Java-friendly methods to Time too, and unit tests. Remove unnecessary math.floor from Time.floor() 4dee32e [Sean Owen] Add named methods to Duration in parallel to symbolic methods for Java-friendliness. Also add unit tests for Duration, in Scala and Java. --- .../org/apache/spark/streaming/Duration.scala | 39 ++++++ .../org/apache/spark/streaming/Time.scala | 20 +++- .../spark/streaming/JavaDurationSuite.java | 84 +++++++++++++ .../apache/spark/streaming/JavaTimeSuite.java | 63 ++++++++++ .../spark/streaming/DurationSuite.scala | 110 +++++++++++++++++ .../apache/spark/streaming/TimeSuite.scala | 111 ++++++++++++++++++ 6 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index 6bf275f5afcb2..a0d8fb5ab93ec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -37,6 +37,25 @@ case class Duration (private val millis: Long) { def / (that: Duration): Double = millis.toDouble / that.millis.toDouble + // Java-friendlier versions of the above. + + def less(that: Duration): Boolean = this < that + + def lessEq(that: Duration): Boolean = this <= that + + def greater(that: Duration): Boolean = this > that + + def greaterEq(that: Duration): Boolean = this >= that + + def plus(that: Duration): Duration = this + that + + def minus(that: Duration): Duration = this - that + + def times(times: Int): Duration = this * times + + def div(that: Duration): Double = this / that + + def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -80,4 +99,24 @@ object Minutes { def apply(minutes: Long) = new Duration(minutes * 60000) } +// Java-friendlier versions of the objects above. +// Named "Durations" instead of "Duration" to avoid changing the case class's implied API. + +object Durations { + + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. + */ + def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. + */ + def seconds(seconds: Long) = Seconds(seconds) + + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. + */ + def minutes(minutes: Long) = Minutes(minutes) + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala index 37b3b28fa01cb..42c49678d24f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -41,10 +41,26 @@ case class Time(private val millis: Long) { def - (that: Duration): Time = new Time(millis - that.milliseconds) + // Java-friendlier versions of the above. + + def less(that: Time): Boolean = this < that + + def lessEq(that: Time): Boolean = this <= that + + def greater(that: Time): Boolean = this > that + + def greaterEq(that: Time): Boolean = this >= that + + def plus(that: Duration): Time = this + that + + def minus(that: Time): Duration = this - that + + def minus(that: Duration): Time = this - that + + def floor(that: Duration): Time = { val t = that.milliseconds - val m = math.floor(this.millis / t).toLong - new Time(m * t) + new Time((this.millis / t) * t) } def isMultipleOf(that: Duration): Boolean = diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java new file mode 100644 index 0000000000000..76425fe2aa2d3 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaDurationSuite { + + // Just testing the methods that are specially exposed for Java. + // This does not repeat all tests found in the Scala suite. + + @Test + public void testLess() { + Assert.assertTrue(new Duration(999).less(new Duration(1000))); + } + + @Test + public void testLessEq() { + Assert.assertTrue(new Duration(1000).lessEq(new Duration(1000))); + } + + @Test + public void testGreater() { + Assert.assertTrue(new Duration(1000).greater(new Duration(999))); + } + + @Test + public void testGreaterEq() { + Assert.assertTrue(new Duration(1000).greaterEq(new Duration(1000))); + } + + @Test + public void testPlus() { + Assert.assertEquals(new Duration(1100), new Duration(1000).plus(new Duration(100))); + } + + @Test + public void testMinus() { + Assert.assertEquals(new Duration(900), new Duration(1000).minus(new Duration(100))); + } + + @Test + public void testTimes() { + Assert.assertEquals(new Duration(200), new Duration(100).times(2)); + } + + @Test + public void testDiv() { + Assert.assertEquals(200.0, new Duration(1000).div(new Duration(5)), 1.0e-12); + } + + @Test + public void testMilliseconds() { + Assert.assertEquals(new Duration(100), Durations.milliseconds(100)); + } + + @Test + public void testSeconds() { + Assert.assertEquals(new Duration(30 * 1000), Durations.seconds(30)); + } + + @Test + public void testMinutes() { + Assert.assertEquals(new Duration(2 * 60 * 1000), Durations.minutes(2)); + } + +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java new file mode 100644 index 0000000000000..ad6b1853e3d12 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaTimeSuite { + + // Just testing the methods that are specially exposed for Java. + // This does not repeat all tests found in the Scala suite. + + @Test + public void testLess() { + Assert.assertTrue(new Time(999).less(new Time(1000))); + } + + @Test + public void testLessEq() { + Assert.assertTrue(new Time(1000).lessEq(new Time(1000))); + } + + @Test + public void testGreater() { + Assert.assertTrue(new Time(1000).greater(new Time(999))); + } + + @Test + public void testGreaterEq() { + Assert.assertTrue(new Time(1000).greaterEq(new Time(1000))); + } + + @Test + public void testPlus() { + Assert.assertEquals(new Time(1100), new Time(1000).plus(new Duration(100))); + } + + @Test + public void testMinusTime() { + Assert.assertEquals(new Duration(900), new Time(1000).minus(new Time(100))); + } + + @Test + public void testMinusDuration() { + Assert.assertEquals(new Time(900), new Time(1000).minus(new Duration(100))); + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala new file mode 100644 index 0000000000000..6202250e897f2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +class DurationSuite extends TestSuiteBase { + + test("less") { + assert(new Duration(999) < new Duration(1000)) + assert(new Duration(0) < new Duration(1)) + assert(!(new Duration(1000) < new Duration(999))) + assert(!(new Duration(1000) < new Duration(1000))) + } + + test("lessEq") { + assert(new Duration(999) <= new Duration(1000)) + assert(new Duration(0) <= new Duration(1)) + assert(!(new Duration(1000) <= new Duration(999))) + assert(new Duration(1000) <= new Duration(1000)) + } + + test("greater") { + assert(!(new Duration(999) > new Duration(1000))) + assert(!(new Duration(0) > new Duration(1))) + assert(new Duration(1000) > new Duration(999)) + assert(!(new Duration(1000) > new Duration(1000))) + } + + test("greaterEq") { + assert(!(new Duration(999) >= new Duration(1000))) + assert(!(new Duration(0) >= new Duration(1))) + assert(new Duration(1000) >= new Duration(999)) + assert(new Duration(1000) >= new Duration(1000)) + } + + test("plus") { + assert((new Duration(1000) + new Duration(100)) == new Duration(1100)) + assert((new Duration(1000) + new Duration(0)) == new Duration(1000)) + } + + test("minus") { + assert((new Duration(1000) - new Duration(100)) == new Duration(900)) + assert((new Duration(1000) - new Duration(0)) == new Duration(1000)) + assert((new Duration(1000) - new Duration(1000)) == new Duration(0)) + } + + test("times") { + assert((new Duration(100) * 2) == new Duration(200)) + assert((new Duration(100) * 1) == new Duration(100)) + assert((new Duration(100) * 0) == new Duration(0)) + } + + test("div") { + assert((new Duration(1000) / new Duration(5)) == 200.0) + assert((new Duration(1000) / new Duration(1)) == 1000.0) + assert((new Duration(1000) / new Duration(1000)) == 1.0) + assert((new Duration(1000) / new Duration(2000)) == 0.5) + } + + test("isMultipleOf") { + assert(new Duration(1000).isMultipleOf(new Duration(5))) + assert(new Duration(1000).isMultipleOf(new Duration(1000))) + assert(new Duration(1000).isMultipleOf(new Duration(1))) + assert(!new Duration(1000).isMultipleOf(new Duration(6))) + } + + test("min") { + assert(new Duration(999).min(new Duration(1000)) == new Duration(999)) + assert(new Duration(1000).min(new Duration(999)) == new Duration(999)) + assert(new Duration(1000).min(new Duration(1000)) == new Duration(1000)) + } + + test("max") { + assert(new Duration(999).max(new Duration(1000)) == new Duration(1000)) + assert(new Duration(1000).max(new Duration(999)) == new Duration(1000)) + assert(new Duration(1000).max(new Duration(1000)) == new Duration(1000)) + } + + test("isZero") { + assert(new Duration(0).isZero) + assert(!(new Duration(1).isZero)) + } + + test("Milliseconds") { + assert(new Duration(100) == Milliseconds(100)) + } + + test("Seconds") { + assert(new Duration(30 * 1000) == Seconds(30)) + } + + test("Minutes") { + assert(new Duration(2 * 60 * 1000) == Minutes(2)) + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala new file mode 100644 index 0000000000000..5579ac364346c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +class TimeSuite extends TestSuiteBase { + + test("less") { + assert(new Time(999) < new Time(1000)) + assert(new Time(0) < new Time(1)) + assert(!(new Time(1000) < new Time(999))) + assert(!(new Time(1000) < new Time(1000))) + } + + test("lessEq") { + assert(new Time(999) <= new Time(1000)) + assert(new Time(0) <= new Time(1)) + assert(!(new Time(1000) <= new Time(999))) + assert(new Time(1000) <= new Time(1000)) + } + + test("greater") { + assert(!(new Time(999) > new Time(1000))) + assert(!(new Time(0) > new Time(1))) + assert(new Time(1000) > new Time(999)) + assert(!(new Time(1000) > new Time(1000))) + } + + test("greaterEq") { + assert(!(new Time(999) >= new Time(1000))) + assert(!(new Time(0) >= new Time(1))) + assert(new Time(1000) >= new Time(999)) + assert(new Time(1000) >= new Time(1000)) + } + + test("plus") { + assert((new Time(1000) + new Duration(100)) == new Time(1100)) + assert((new Time(1000) + new Duration(0)) == new Time(1000)) + } + + test("minus Time") { + assert((new Time(1000) - new Time(100)) == new Duration(900)) + assert((new Time(1000) - new Time(0)) == new Duration(1000)) + assert((new Time(1000) - new Time(1000)) == new Duration(0)) + } + + test("minus Duration") { + assert((new Time(1000) - new Duration(100)) == new Time(900)) + assert((new Time(1000) - new Duration(0)) == new Time(1000)) + assert((new Time(1000) - new Duration(1000)) == new Time(0)) + } + + test("floor") { + assert(new Time(1350).floor(new Duration(200)) == new Time(1200)) + assert(new Time(1200).floor(new Duration(200)) == new Time(1200)) + assert(new Time(199).floor(new Duration(200)) == new Time(0)) + assert(new Time(1).floor(new Duration(1)) == new Time(1)) + } + + test("isMultipleOf") { + assert(new Time(1000).isMultipleOf(new Duration(5))) + assert(new Time(1000).isMultipleOf(new Duration(1000))) + assert(new Time(1000).isMultipleOf(new Duration(1))) + assert(!new Time(1000).isMultipleOf(new Duration(6))) + } + + test("min") { + assert(new Time(999).min(new Time(1000)) == new Time(999)) + assert(new Time(1000).min(new Time(999)) == new Time(999)) + assert(new Time(1000).min(new Time(1000)) == new Time(1000)) + } + + test("max") { + assert(new Time(999).max(new Time(1000)) == new Time(1000)) + assert(new Time(1000).max(new Time(999)) == new Time(1000)) + assert(new Time(1000).max(new Time(1000)) == new Time(1000)) + } + + test("until") { + assert(new Time(1000).until(new Time(1100), new Duration(100)) == + Seq(Time(1000))) + assert(new Time(1000).until(new Time(1000), new Duration(100)) == + Seq()) + assert(new Time(1000).until(new Time(1100), new Duration(30)) == + Seq(Time(1000), Time(1030), Time(1060), Time(1090))) + } + + test("to") { + assert(new Time(1000).to(new Time(1100), new Duration(100)) == + Seq(Time(1000), Time(1100))) + assert(new Time(1000).to(new Time(1000), new Duration(100)) == + Seq(Time(1000))) + assert(new Time(1000).to(new Time(1100), new Duration(30)) == + Seq(Time(1000), Time(1030), Time(1060), Time(1090))) + } + +} From ae60f8fb2d879ee1ebc0746bcbe05b89ab6ed3c9 Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 23 Sep 2014 11:59:44 -0700 Subject: [PATCH 069/315] [SPARK-3481][SQL] removes the evil MINOR HACK a follow up of https://github.com/apache/spark/pull/2377 and https://github.com/apache/spark/pull/2352, see detail there. Author: wangfei Closes #2505 from scwf/patch-6 and squashes the following commits: 4874ec8 [wangfei] removes the evil MINOR HACK --- .../org/apache/spark/sql/hive/execution/PruningSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8275e2d3bcce3..8474d850c9c6c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -28,8 +28,6 @@ import scala.collection.JavaConversions._ * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") TestHive.cacheTables = false // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset From 1c62f97e94de96ca3dc6daf778f008176e92888a Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Tue, 23 Sep 2014 12:17:47 -0700 Subject: [PATCH 070/315] [SPARK-3268][SQL] DoubleType, FloatType and DecimalType modulus support Supported modulus operation using % operator on fractional datatypes FloatType, DoubleType and DecimalType Example: SELECT 1388632775.0 % 60 from tablename LIMIT 1 Author : Venkata Ramana Gollamudi ramana.gollamudihuawei.com Author: Venkata Ramana Gollamudi Closes #2457 from gvramana/double_modulus_support and squashes the following commits: 79172a8 [Venkata Ramana Gollamudi] Add hive cache to testcase c09bd5b [Venkata Ramana Gollamudi] Added a HiveQuerySuite testcase 193fa81 [Venkata Ramana Gollamudi] corrected testcase 3624471 [Venkata Ramana Gollamudi] modified testcase e112c09 [Venkata Ramana Gollamudi] corrected the testcase 513d0e0 [Venkata Ramana Gollamudi] modified to add modulus support to fractional types float,double,decimal 296d253 [Venkata Ramana Gollamudi] modified to add modulus support to fractional types float,double,decimal --- .../sql/catalyst/expressions/Expression.scala | 3 ++ .../spark/sql/catalyst/types/dataTypes.scala | 5 +++ .../ExpressionEvaluationSuite.scala | 32 +++++++++++++++++++ ...modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 3 ++ 5 files changed, 44 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 70507e7ee2be8..1eb260efa6387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,6 +179,9 @@ abstract class Expression extends TreeNode[Expression] { case i: IntegralType => f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case i: FractionalType => + f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( + i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) case other => sys.error(s"Type $other does not support numeric operations") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index e3050e5397937..c7d73d3990c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp +import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers @@ -250,6 +251,7 @@ object FractionalType { } abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] + private[sql] val asIntegral: Integral[JvmType] } case object DecimalType extends FractionalType { @@ -258,6 +260,7 @@ case object DecimalType extends FractionalType { private[sql] val numeric = implicitly[Numeric[BigDecimal]] private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = BigDecimalAsIfIntegral def simpleString: String = "decimal" } @@ -267,6 +270,7 @@ case object DoubleType extends FractionalType { private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = DoubleAsIfIntegral def simpleString: String = "double" } @@ -276,6 +280,7 @@ case object FloatType extends FractionalType { private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = FloatAsIfIntegral def simpleString: String = "float" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8b6721d5d8125..63931af4bac3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp import org.scalatest.FunSuite +import org.scalatest.Matchers._ +import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ @@ -129,6 +131,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } + test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) @@ -471,6 +480,29 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 % c2, 1, row) } + test("fractional arithmetic") { + val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val c1 = 'a.double.at(0) + val c2 = 'a.double.at(1) + val c3 = 'a.double.at(2) + val c4 = 'a.double.at(3) + + checkEvaluation(UnaryMinus(c1), -1.1, row) + checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3.1, row) + checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) + checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) + checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + + checkEvaluation(-c1, -1.1, row) + checkEvaluation(c1 + c2, 3.1, row) + checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) + checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) + checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) + checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + } + test("BinaryComparison") { val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) val c1 = 'a.int.at(0) diff --git a/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e b/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e new file mode 100644 index 0000000000000..52eab0653c505 --- /dev/null +++ b/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e @@ -0,0 +1 @@ +1 true 0.5 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6fc891ba4cca5..426f5fcee6157 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -138,6 +138,9 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("division", "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") + createQueryTest("modulus", + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) From a08153f8a3e7bad81bae330ec4152651da5e7804 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 23 Sep 2014 12:27:12 -0700 Subject: [PATCH 071/315] [SPARK-3646][SQL] Copy SQL configuration from SparkConf when a SQLContext is created. This will allow us to take advantage of things like the spark.defaults file. Author: Michael Armbrust Closes #2493 from marmbrus/copySparkConf and squashes the following commits: 0bd1377 [Michael Armbrust] Copy SQL configuration from SparkConf when a SQLContext is created. --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 5 +++++ .../org/apache/spark/sql/test/TestSQLContext.scala | 6 +++++- .../scala/org/apache/spark/sql/SQLConfSuite.scala | 11 ++++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b245e1a863cc3..a42bedbe6c04e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -75,6 +75,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => setConf(key, value) + case _ => + } + /** * :: DeveloperApi :: * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 265b67737c475..6bb81c76ed8bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -22,7 +22,11 @@ import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + extends SQLContext( + new SparkContext( + "local[2]", + "TestSQLContext", + new SparkConf().set("spark.sql.testkey", "true"))) { /** Fewer partitions to speed up testing. */ override private[spark] def numShufflePartitions: Int = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 584f71b3c13d5..60701f0e154f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,16 +17,25 @@ package org.apache.spark.sql +import org.scalatest.FunSuiteLike + import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest { +class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" val testVal = "test.val.0" + test("propagate from spark conf") { + // We create a new context here to avoid order dependence with other tests that might call + // clear(). + val newContext = new SQLContext(TestSQLContext.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") == "true") + } + test("programmatic ways of basic setting and getting") { clear() assert(getAllConfs.size === 0) From 8dfe79ffb204807945e3c09b75c7255b09ad2a97 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Sep 2014 13:42:00 -0700 Subject: [PATCH 072/315] [SPARK-3647] Add more exceptions to Guava relocation. Guava's Optional refers to some package private classes / methods, and when those are relocated the code stops working, throwing exceptions. So add the affected classes to the exception list too, and add a unit test. (Note that this unit test only really makes sense in maven, since we don't relocate in the sbt build. Also, JavaAPISuite doesn't seem to be run by "mvn test" - I had to manually add command line options to enable it.) Author: Marcelo Vanzin Closes #2496 from vanzin/SPARK-3647 and squashes the following commits: 84f58d7 [Marcelo Vanzin] [SPARK-3647] Add more exceptions to Guava relocation. --- assembly/pom.xml | 4 ++- core/pom.xml | 2 ++ .../java/org/apache/spark/JavaAPISuite.java | 26 +++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 604b1ab3de6a8..5ec9da22ae83f 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -141,7 +141,9 @@ com.google.common.** - com.google.common.base.Optional** + com/google/common/base/Absent* + com/google/common/base/Optional* + com/google/common/base/Present* diff --git a/core/pom.xml b/core/pom.xml index 2a81f6df289c0..e012c5e673b74 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -343,7 +343,9 @@ com.google.guava:guava + com/google/common/base/Absent* com/google/common/base/Optional* + com/google/common/base/Present* diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8574dfb42e6b..b8c23d524e00b 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1307,4 +1307,30 @@ public void collectUnderlyingScalaRDD() { SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); Assert.assertEquals(data.size(), collected.length); } + + /** + * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, + * since that's the only artifact where Guava classes have been relocated. + */ + @Test + public void testGuavaOptional() { + // Stop the context created in setUp() and start a local-cluster one, to force usage of the + // assembly. + sc.stop(); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + try { + JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); + JavaRDD> rdd2 = rdd1.map( + new Function>() { + @Override + public Optional call(Integer i) { + return Optional.fromNullable(i); + } + }); + rdd2.collect(); + } finally { + localCluster.stop(); + } + } + } From d79238d03a2ffe0cf5fc6166543d67768693ddbe Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 23 Sep 2014 13:44:18 -0700 Subject: [PATCH 073/315] SPARK-3612. Executor shouldn't quit if heartbeat message fails to reach ... ...the driver Author: Sandy Ryza Closes #2487 from sryza/sandy-spark-3612 and squashes the following commits: 2b7353d [Sandy Ryza] SPARK-3612. Executor shouldn't quit if heartbeat message fails to reach the driver --- .../org/apache/spark/executor/Executor.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index acae448a9c66f..d7211ae465902 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -24,6 +24,7 @@ import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -375,12 +376,17 @@ private[spark] class Executor( } val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() + try { + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + } catch { + case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t) } + Thread.sleep(interval) } } From b3fef50e22fb3fe499f627179d17836a92dcb33a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 23 Sep 2014 14:00:33 -0700 Subject: [PATCH 074/315] [SPARK-3653] Respect SPARK_*_MEMORY for cluster mode `SPARK_DRIVER_MEMORY` was only used to start the `SparkSubmit` JVM, which becomes the driver only in client mode but not cluster mode. In cluster mode, this property is simply not propagated to the worker nodes. `SPARK_EXECUTOR_MEMORY` is picked up from `SparkContext`, but in cluster mode the driver runs on one of the worker machines, where this environment variable may not be set. Author: Andrew Or Closes #2500 from andrewor14/memory-env-vars and squashes the following commits: 6217b38 [Andrew Or] Respect SPARK_*_MEMORY for cluster mode --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 92e0917743ed1..2b72c61cc8177 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -75,6 +75,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } + // Respect SPARK_*_MEMORY for cluster mode + driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull + executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull + parseOpts(args.toList) mergeSparkProperties() checkRequiredArguments() From 729952a5efce755387c76cdf29280ee6f49fdb72 Mon Sep 17 00:00:00 2001 From: Mubarak Seyed Date: Tue, 23 Sep 2014 15:09:12 -0700 Subject: [PATCH 075/315] [SPARK-1853] Show Streaming application code context (file, line number) in Spark Stages UI This is a refactored version of the original PR https://github.com/apache/spark/pull/1723 my mubarak Please take a look andrewor14, mubarak Author: Mubarak Seyed Author: Tathagata Das Closes #2464 from tdas/streaming-callsite and squashes the following commits: dc54c71 [Tathagata Das] Made changes based on PR comments. 390b45d [Tathagata Das] Fixed minor bugs. 904cd92 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into streaming-callsite 7baa427 [Tathagata Das] Refactored getCallSite and setCallSite to make it simpler. Also added unit test for DStream creation site. b9ed945 [Mubarak Seyed] Adding streaming utils c461cf4 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' ceb43da [Mubarak Seyed] Changing default regex function name 8c5d443 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 196121b [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 491a1eb [Mubarak Seyed] Removing streaming visibility from getRDDCreationCallSite in DStream 33a7295 [Mubarak Seyed] Fixing review comments: Merging both setCallSite methods c26d933 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' f51fd9f [Mubarak Seyed] Fixing scalastyle, Regex for Utils.getCallSite, and changing method names in DStream 5051c58 [Mubarak Seyed] Getting return value of compute() into variable and call setCallSite(prevCallSite) only once. Adding return for other code paths (for None) a207eb7 [Mubarak Seyed] Fixing code review comments ccde038 [Mubarak Seyed] Removing Utils import from MappedDStream 2a09ad6 [Mubarak Seyed] Changes in Utils.scala for SPARK-1853 1d90cc3 [Mubarak Seyed] Changes for SPARK-1853 5f3105a [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 70f494f [Mubarak Seyed] Changes for SPARK-1853 1500deb [Mubarak Seyed] Changes in Spark Streaming UI 9d38d3c [Mubarak Seyed] [SPARK-1853] Show Streaming application code context (file, line number) in Spark Stages UI d466d75 [Mubarak Seyed] Changes for spark streaming UI --- .../scala/org/apache/spark/SparkContext.scala | 32 +++++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 7 +- .../scala/org/apache/spark/util/Utils.scala | 27 ++++-- .../spark/streaming/StreamingContext.scala | 4 +- .../spark/streaming/dstream/DStream.scala | 96 ++++++++++++------- .../streaming/StreamingContextSuite.scala | 45 ++++++++- 6 files changed, 153 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 428f019b02a23..979d178c35969 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1030,28 +1030,40 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. */ - def setCallSite(site: String) { - setLocalProperty("externalCallSite", site) + def setCallSite(shortCallSite: String) { + setLocalProperty(CallSite.SHORT_FORM, shortCallSite) } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. + */ + private[spark] def setCallSite(callSite: CallSite) { + setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm) + setLocalProperty(CallSite.LONG_FORM, callSite.longForm) + } + + /** + * Clear the thread-local property for overriding the call sites + * of actions and RDDs. */ def clearCallSite() { - setLocalProperty("externalCallSite", null) + setLocalProperty(CallSite.SHORT_FORM, null) + setLocalProperty(CallSite.LONG_FORM, null) } /** * Capture the current user callsite and return a formatted version for printing. If the user - * has overridden the call site, this will return the user's version. + * has overridden the call site using `setCallSite()`, this will return the user's version. */ private[spark] def getCallSite(): CallSite = { - Option(getLocalProperty("externalCallSite")) match { - case Some(callSite) => CallSite(callSite, longForm = "") - case None => Utils.getCallSite - } + Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite => + val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("") + CallSite(shortCallSite, longCallSite) + }.getOrElse(Utils.getCallSite()) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a9b905b0d1a63..0e90caa5c9ca7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.util.Random +import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer @@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1224,7 +1224,8 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ - @transient private[spark] val creationSite = Utils.getCallSite + @transient private[spark] val creationSite = sc.getCallSite() + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ed063844323af..2755887feeeff 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) +private[spark] object CallSite { + val SHORT_FORM = "callSite.short" + val LONG_FORM = "callSite.long" +} + /** * Various utility methods used by Spark. */ @@ -859,18 +864,26 @@ private[spark] object Utils extends Logging { } } - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + /** Default filtering function for finding call sites using `getCallSite`. */ + private def coreExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the "core" Spark API that we want to skip when + // finding the call site of a method. + val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + val SCALA_CLASS_REGEX = """^scala""".r + val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + // If the class is a Spark internal class or a Scala class, then exclude. + isSparkCoreClass || isScalaClass + } /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. + * + * @param skipClass Function that is used to exclude non-user-code classes. */ - def getCallSite: CallSite = { + def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { val trace = Thread.currentThread.getStackTrace() .filterNot { ste:StackTraceElement => // When running under some profilers, the current stack trace might contain some bogus @@ -891,7 +904,7 @@ private[spark] object Utils extends Logging { for (el <- trace) { if (insideSpark) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) { + if (skipClass(el.getClassName)) { lastSparkMethod = if (el.getMethodName == "") { // Spark method is a constructor; get its class name el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f63560dcb5b89..5a8eef1372e23 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -35,10 +35,9 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver} +import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.MetadataCleaner /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -448,6 +447,7 @@ class StreamingContext private[streaming] ( throw new SparkException("StreamingContext has already been stopped") } validate() + sparkContext.setCallSite(DStream.getCreationSite()) scheduler.start() state = Started } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e05db236addca..65f7ccd318684 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -23,6 +23,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag +import scala.util.matching.Regex import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} @@ -30,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{CallSite, MetadataCleaner} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -106,6 +107,9 @@ abstract class DStream[T: ClassTag] ( /** Return the StreamingContext associated with this DStream */ def context = ssc + /* Set the creation call site */ + private[streaming] val creationSite = DStream.getCreationSite() + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -272,43 +276,41 @@ abstract class DStream[T: ClassTag] ( } /** - * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal - * method that should not be called directly. + * Get the RDD corresponding to the given time; either retrieve it from cache + * or compute-and-cache it. */ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { - // If this DStream was not initialized (i.e., zeroTime not set), then do it - // If RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { - - // If an RDD was already generated and is being reused, then - // probably all RDDs in this DStream will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this DStream), then generate the RDD - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting RDD " + newRDD.id + " for time " + - time + " to " + storageLevel + " at time " + time) - } - if (checkpointDuration != null && - (time - zeroTime).isMultipleOf(checkpointDuration)) { - newRDD.checkpoint() - logInfo("Marking RDD " + newRDD.id + " for time " + time + - " for checkpointing at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - case None => - None + // If RDD was already generated, then retrieve it from HashMap, + // or else compute the RDD + generatedRDDs.get(time).orElse { + // Compute the RDD if time is valid (e.g. correct time in a sliding window) + // of RDD generation, else generate nothing. + if (isTimeValid(time)) { + // Set the thread-local property for call sites to this DStream's creation site + // such that RDDs generated by compute gets that as their creation site. + // Note that this `getOrCompute` may get called from another DStream which may have + // set its own call site. So we store its call site in a temporary variable, + // set this DStream's creation site, generate RDDs and then restore the previous call site. + val prevCallSite = ssc.sparkContext.getCallSite() + ssc.sparkContext.setCallSite(creationSite) + val rddOption = compute(time) + ssc.sparkContext.setCallSite(prevCallSite) + + rddOption.foreach { case newRDD => + // Register the generated RDD for caching and checkpointing + if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel") } - } else { - None + if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { + newRDD.checkpoint() + logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing") + } + generatedRDDs.put(time, newRDD) } + rddOption + } else { + None } } } @@ -799,3 +801,29 @@ abstract class DStream[T: ClassTag] ( this } } + +private[streaming] object DStream { + + /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ + def getCreationSite(): CallSite = { + val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r + val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + val SCALA_CLASS_REGEX = """^scala""".r + + /** Filtering function that excludes non-user classes for a streaming application */ + def streamingExclustionFunction(className: String): Boolean = { + def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + val isSparkClass = doesMatch(SPARK_CLASS_REGEX) + val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) + val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) + val isScalaClass = doesMatch(SCALA_CLASS_REGEX) + + // If the class is a spark example class or a streaming test class then it is considered + // as a streaming application class and don't exclude. Otherwise, exclude any + // non-Spark and non-Scala class, as the rest would streaming application classes. + (isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass + } + org.apache.spark.util.Utils.getCallSite(streamingExclustionFunction) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index a3cabd6be02fe..ebf83748ffa28 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,13 +19,16 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger +import scala.language.postfixOps + import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.{MetadataCleaner, Utils} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.util.Utils +import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -257,6 +260,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("DStream and generated RDD creation sites") { + testPackage.test() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => (1 to i)) val inputStream = new TestInputStream(s, input, 1) @@ -293,3 +300,37 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging object TestReceiver { val counter = new AtomicInteger(1) } + +/** Streaming application for testing DStream and RDD creation sites */ +package object testPackage extends Assertions { + def test() { + val conf = new SparkConf().setMaster("local").setAppName("CreationSite test") + val ssc = new StreamingContext(conf , Milliseconds(100)) + try { + val inputStream = ssc.receiverStream(new TestReceiver) + + // Verify creation site of DStream + val creationSite = inputStream.creationSite + assert(creationSite.shortForm.contains("receiverStream") && + creationSite.shortForm.contains("StreamingContextSuite") + ) + assert(creationSite.longForm.contains("testPackage")) + + // Verify creation site of generated RDDs + var rddGenerated = false + var rddCreationSiteCorrect = true + + inputStream.foreachRDD { rdd => + rddCreationSiteCorrect = rdd.creationSite == creationSite + rddGenerated = true + } + ssc.start() + + eventually(timeout(10000 millis), interval(10 millis)) { + assert(rddGenerated && rddCreationSiteCorrect, "RDD creation site was not correct") + } + } finally { + ssc.stop() + } + } +} From c429126066f766396b706894b6942f1ca7fcb528 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 24 Sep 2014 11:33:58 -0700 Subject: [PATCH 076/315] [Build] Diff from branch point Sometimes Jenkins posts [spurious reports of new classes being added](https://github.com/apache/spark/pull/2339#issuecomment-56570170). I believe this stems from diffing the patch against `master`, as opposed to against `master...`, which starts from the commit the PR was branched from. This patch fixes that behavior. Author: Nicholas Chammas Closes #2512 from nchammas/diff-only-commits-ahead and squashes the following commits: c065599 [Nicholas Chammas] comment typo fix a453c67 [Nicholas Chammas] diff from branch point --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 06c3781eb3ccf..a6ecf3196d7d4 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -92,13 +92,13 @@ function post_message () { merge_note=" * This patch merges cleanly." source_files=$( - git diff master --name-only \ + git diff master... --name-only `# diff patch against master from branch point` \ | grep -v -e "\/test" `# ignore files in test directories` \ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ | tr "\n" " " ) new_public_classes=$( - git diff master ${source_files} `# diff this patch against master and...` \ + git diff master... ${source_files} `# diff patch against master from branch point` \ | grep "^\+" `# filter in only added lines` \ | sed -r -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ From 50f863365348d52a9285fc779efbedbf1567ea11 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 24 Sep 2014 11:34:39 -0700 Subject: [PATCH 077/315] [SPARK-3659] Set EC2 version to 1.1.0 and update version map This brings the master branch in sync with branch-1.1 Author: Shivaram Venkataraman Closes #2510 from shivaram/spark-ec2-version and squashes the following commits: bb0dd16 [Shivaram Venkataraman] Set EC2 version to 1.1.0 and update version map --- ec2/spark_ec2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index fbeccd89b43b3..7f2cd7d94de39 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -38,7 +38,7 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 -DEFAULT_SPARK_VERSION = "1.0.0" +DEFAULT_SPARK_VERSION = "1.1.0" # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" @@ -218,7 +218,7 @@ def is_active(instance): def get_spark_shark_version(opts): spark_shark_map = { "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0" + "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: From c854b9fcb5595b1d70b6ce257fc7574602ac5e49 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 12:10:09 -0700 Subject: [PATCH 078/315] [SPARK-3634] [PySpark] User's module should take precedence over system modules Python modules added through addPyFile should take precedence over system modules. This patch put the path for user added module in the front of sys.path (just after ''). Author: Davies Liu Closes #2492 from davies/path and squashes the following commits: 4a2af78 [Davies Liu] fix tests f7ff4da [Davies Liu] ad license header 6b0002f [Davies Liu] add tests c16c392 [Davies Liu] put addPyFile in front of sys.path --- python/pyspark/context.py | 11 +++++------ python/pyspark/tests.py | 12 ++++++++++++ python/pyspark/worker.py | 11 +++++++++-- python/test_support/SimpleHTTPServer.py | 22 ++++++++++++++++++++++ 4 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 python/test_support/SimpleHTTPServer.py diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 064a24bff539c..8e7b00469e246 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, SparkFiles._sc = self root_dir = SparkFiles.getRootDirectory() - sys.path.append(root_dir) + sys.path.insert(1, root_dir) # Deploy any code dependencies specified in the constructor self._python_includes = list() @@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - self._python_includes.append(filename) - sys.path.append(path) - if dirname not in sys.path: - sys.path.append(dirname) + if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + self._python_includes.append(filename) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -667,7 +666,7 @@ def addPyFile(self, path): if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) # for tests in local mode - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1b8afb763b26a..4483bf80dbe06 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,18 @@ def func(): from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + class TestRDDFunctions(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d6c06e2dbef62..c1f6e3e4a1f40 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish): write_long(1000 * finish, outfile) +def add_path(path): + # worker can be used, so donot add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + def main(infile, outfile): try: boot_time = time.time() @@ -61,11 +68,11 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here + add_path(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) - sys.path.append(os.path.join(spark_files_dir, filename)) + add_path(os.path.join(spark_files_dir, filename)) # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) diff --git a/python/test_support/SimpleHTTPServer.py b/python/test_support/SimpleHTTPServer.py new file mode 100644 index 0000000000000..eddbd588e02dc --- /dev/null +++ b/python/test_support/SimpleHTTPServer.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Used to test override standard SimpleHTTPServer module. +""" + +__name__ = "My Server" From bb96012b7360b099a19fecc80f0209b30f118ada Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 13:00:05 -0700 Subject: [PATCH 079/315] [SPARK-3679] [PySpark] pickle the exact globals of functions function.func_code.co_names has all the names used in the function, including name of attributes. It will pickle some unnecessary globals if there is a global having the same name with attribute (in co_names). There is a regression introduced by #2144, revert part of changes in that PR. cc JoshRosen Author: Davies Liu Closes #2522 from davies/globals and squashes the following commits: dfbccf5 [Davies Liu] fix bug while pickle globals of function --- python/pyspark/cloudpickle.py | 42 ++++++++++++++++++++++++++++++----- python/pyspark/tests.py | 18 +++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 32dda3888c62d..bb0783555aa77 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,6 +52,7 @@ import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new +import dis import traceback import platform @@ -61,6 +62,14 @@ import logging cloudLog = logging.getLogger("Cloud.Transport") +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) if PyImp == "PyPy": # register builtin type in `new` @@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports): write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(code): + def extract_code_globals(co): """ Find all globals names read or written to by codeblock co """ - names = set(code.co_names) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + + if co.co_consts: # see if nested function have any global refs + for const in co.co_consts: if type(const) is types.CodeType: - names |= CloudPickler.extract_code_globals(const) - return names + out_names |= CloudPickler.extract_code_globals(const) + + return out_names def extract_func_data(self, func): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4483bf80dbe06..d1bb2033b7a16 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -213,6 +213,24 @@ def test_pickling_file_handles(self): out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.func_code.co_names) + ser.dumps(foo) + class PySparkTestCase(unittest.TestCase): From 74fb2ecf7afc2d314f6477f8f2e6134614387453 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 24 Sep 2014 17:18:55 -0700 Subject: [PATCH 080/315] [SPARK-3615][Streaming]Fix Kafka unit test hard coded Zookeeper port issue Details can be seen in [SPARK-3615](https://issues.apache.org/jira/browse/SPARK-3615). Author: jerryshao Closes #2483 from jerryshao/SPARK_3615 and squashes the following commits: 8555563 [jerryshao] Fix Kafka unit test hard coded Zookeeper port issue --- .../streaming/kafka/JavaKafkaStreamSuite.java | 2 +- .../streaming/kafka/KafkaStreamSuite.scala | 46 +++++++++++++------ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 0571454c01dae..efb0099c7c850 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -81,7 +81,7 @@ public void testKafkaStream() throws InterruptedException { Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index c0b55e9340253..6943326eb750e 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -24,7 +24,7 @@ import java.util.{Properties, Random} import scala.collection.mutable import kafka.admin.CreateTopicCommand -import kafka.common.TopicAndPartition +import kafka.common.{KafkaException, TopicAndPartition} import kafka.producer.{KeyedMessage, ProducerConfig, Producer} import kafka.utils.ZKStringSerializer import kafka.serializer.{StringDecoder, StringEncoder} @@ -42,14 +42,13 @@ import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { import KafkaTestUtils._ - val zkConnect = "localhost:2181" + val zkHost = "localhost" + var zkPort: Int = 0 val zkConnectionTimeout = 6000 val zkSessionTimeout = 6000 - val brokerPort = 9092 - val brokerProps = getBrokerConfig(brokerPort, zkConnect) - val brokerConf = new KafkaConfig(brokerProps) - + protected var brokerPort = 9092 + protected var brokerConf: KafkaConfig = _ protected var zookeeper: EmbeddedZookeeper = _ protected var zkClient: ZkClient = _ protected var server: KafkaServer = _ @@ -59,16 +58,35 @@ class KafkaStreamSuite extends TestSuiteBase { override def beforeFunction() { // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(zkConnect) + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort logInfo("==================== 0 ====================") - zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) logInfo("==================== 1 ====================") // Kafka broker startup - server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") - server.startup() - logInfo("==================== 3 ====================") + var bindSuccess: Boolean = false + while(!bindSuccess) { + try { + val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + brokerConf = new KafkaConfig(brokerProps) + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + bindSuccess = true + } catch { + case e: KafkaException => + if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { + brokerPort += 1 + } + case e: Exception => throw new Exception("Kafka server create failed", e) + } + } + Thread.sleep(2000) logInfo("==================== 4 ====================") super.beforeFunction() @@ -92,7 +110,7 @@ class KafkaStreamSuite extends TestSuiteBase { createTopic(topic) produceAndSendMessage(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkConnect, + val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", "group.id" -> s"test-consumer-${random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -200,6 +218,8 @@ object KafkaTestUtils { factory.configure(new InetSocketAddress(ip, port), 16) factory.startup(zookeeper) + val actualPort = factory.getLocalPort + def shutdown() { factory.shutdown() Utils.deleteRecursively(snapshotDir) From 8ca4ecb6a56b96bae21b33e27f6abdb53676683a Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Wed, 24 Sep 2014 20:39:09 -0700 Subject: [PATCH 081/315] [SPARK-546] Add full outer join to RDD and DStream. leftOuterJoin and rightOuterJoin are already implemented. This patch adds fullOuterJoin. Author: Aaron Staple Closes #1395 from staple/SPARK-546 and squashes the following commits: 1f5595c [Aaron Staple] Fix python style 7ac0aa9 [Aaron Staple] [SPARK-546] Add full outer join to RDD and DStream. 3b5d137 [Aaron Staple] In JavaPairDStream, make class tag specification in rightOuterJoin consistent with other functions. 31f2956 [Aaron Staple] Fix left outer join documentation comments. --- .../apache/spark/api/java/JavaPairRDD.scala | 48 +++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 42 +++++++++++++++ .../org/apache/spark/PartitioningSuite.scala | 3 ++ .../spark/rdd/PairRDDFunctionsSuite.scala | 15 ++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 1 + docs/programming-guide.md | 2 +- python/pyspark/join.py | 16 ++++++ python/pyspark/rdd.py | 25 ++++++++- .../streaming/api/java/JavaPairDStream.scala | 54 +++++++++++++++++-- .../dstream/PairDStreamFunctions.scala | 36 +++++++++++++ .../streaming/BasicOperationsSuite.scala | 15 ++++++ 11 files changed, 250 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 880f61c49726e..0846225e4f992 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -469,6 +469,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Uses the given Partitioner to partition the output RDD. + */ + def fullOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) + : JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other, partitioner) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing * partitioner/parallelism level. @@ -563,6 +579,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ + * parallelism level. + */ + def fullOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. + */ + def fullOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int) + : JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other, numPartitions) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + /** * Return the key-value pairs in this RDD to the master as a Map. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 51ba8c2d17834..7f578bc5dac39 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -506,6 +506,23 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Uses the given Partitioner to partition the output RDD. + */ + def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) + : RDD[(K, (Option[V], Option[W]))] = { + this.cogroup(other, partitioner).flatMapValues { + case (vs, Seq()) => vs.map(v => (Some(v), None)) + case (Seq(), ws) => ws.map(w => (None, Some(w))) + case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w)) + } + } + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. @@ -585,6 +602,31 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) rightOuterJoin(other, new HashPartitioner(numPartitions)) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ + * parallelism level. + */ + def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = { + fullOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. + */ + def fullOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = { + fullOuterJoin(other, new HashPartitioner(numPartitions)) + } + /** * Return the key-value pairs in this RDD to the master as a Map. * diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index fc0cee3e8749d..646ede30ae6ff 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -193,11 +193,13 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.fullOuterJoin(grouped4).partitioner === grouped4.partitioner) assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.fullOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) assert(grouped2.map(_ => 1).partitioner === None) @@ -218,6 +220,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index e84cc69592339..75b01191901b8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -298,6 +298,21 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("fullOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.fullOuterJoin(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (Some(1), Some('x'))), + (1, (Some(2), Some('x'))), + (2, (Some(1), Some('y'))), + (2, (Some(1), Some('z'))), + (3, (Some(1), None)), + (4, (None, Some('w'))) + )) + } + test("join with no matches") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index c1b501a75c8b8..465c1a8a43a79 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -193,6 +193,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(rdd.join(emptyKv).collect().size === 0) assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.fullOuterJoin(emptyKv).collect().size === 2) assert(rdd.cogroup(emptyKv).collect().size === 2) assert(rdd.union(emptyKv).collect().size === 2) } diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 01d378af574b5..510b47a2aaad1 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -906,7 +906,7 @@ for details. diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b0f1cc1927066..b4a844713745a 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -80,6 +80,22 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) +def python_full_outer_join(rdd, other, numPartitions): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numPartitions, dispatch) + + def python_cogroup(rdds, numPartitions): def make_mapper(i): return lambda (k, v): (k, (i, v)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ef233bc80c5c..680140d72d03c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -36,7 +36,7 @@ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer from pyspark.join import python_join, python_left_outer_join, \ - python_right_outer_join, python_cogroup + python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel @@ -1375,7 +1375,7 @@ def leftOuterJoin(self, other, numPartitions=None): For each element (k, v) in C{self}, the resulting RDD will either contain all pairs (k, (v, w)) for w in C{other}, or the pair - (k, (v, None)) if no elements in other have key k. + (k, (v, None)) if no elements in C{other} have key k. Hash-partitions the resulting RDD into the given number of partitions. @@ -1403,6 +1403,27 @@ def rightOuterJoin(self, other, numPartitions=None): """ return python_right_outer_join(self, other, numPartitions) + def fullOuterJoin(self, other, numPartitions=None): + """ + Perform a right outer join of C{self} and C{other}. + + For each element (k, v) in C{self}, the resulting RDD will either + contain all pairs (k, (v, w)) for w in C{other}, or the pair + (k, (v, None)) if no elements in C{other} have key k. + + Similarly, for each element (k, w) in C{other}, the resulting RDD will + either contain all pairs (k, (v, w)) for v in C{self}, or the pair + (k, (None, w)) if no elements in C{self} have key k. + + Hash-partitions the resulting RDD into the given number of partitions. + + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("c", 8)]) + >>> sorted(x.fullOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None)), ('c', (None, 8))] + """ + return python_full_outer_join(self, other, numPartitions) + # TODO: add option to control map-side combining # portable_hash is used as default, because builtin hash of None is different # cross machines. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index c00e11d11910f..59d4423086ef0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -606,8 +606,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. - * The supplied org.apache.spark.Partitioner is used to control the partitioning of each RDD. + * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. */ def leftOuterJoin[W]( other: JavaPairDStream[K, W], @@ -624,8 +625,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * number of partitions. */ def rightOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (Optional[V], W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.rightOuterJoin(other.dstream) joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } @@ -658,6 +658,52 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + */ + def fullOuterJoin[W](other: JavaPairDStream[K, W]) + : JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + */ + def fullOuterJoin[W]( + other: JavaPairDStream[K, W], + numPartitions: Int + ): JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream, numPartitions) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. + */ + def fullOuterJoin[W]( + other: JavaPairDStream[K, W], + partitioner: Partitioner + ): JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream, partitioner) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + /** * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 826bf39e860e1..9467595d307a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -568,6 +568,42 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) ) } + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + */ + def fullOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = { + fullOuterJoin[W](other, defaultPartitioner()) + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + */ + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int + ): DStream[(K, (Option[V], Option[W]))] = { + fullOuterJoin[W](other, defaultPartitioner(numPartitions)) + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. + */ + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (Option[V], Option[W]))] = { + self.transformWith( + other, + (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.fullOuterJoin(rdd2, partitioner) + ) + } + /** * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 059ac6c2dbee2..6c8bb50145367 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -303,6 +303,21 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData1, inputData2, operation, outputData, true) } + test("fullOuterJoin") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (Some(1), Some("x"))), ("b", (Some(1), Some("x"))) ), + Seq( ("", (Some(1), Some("x"))), ("a", (Some(1), None)), ("b", (None, Some("x"))) ), + Seq( ("", (Some(1), None)) ), + Seq( ("", (None, Some("x"))) ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x, 1)).fullOuterJoin(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + test("updateStateByKey") { val inputData = Seq( From b8487713d3bf288a4f6fc149e6ee4cc8196d6e7d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 24 Sep 2014 23:10:26 -0700 Subject: [PATCH 082/315] [SPARK-2778] [yarn] Add yarn integration tests. This patch adds a couple of, currently, very simple integration tests to make sure both client and cluster modes are working. The tests don't do much yet other than run a simple job, but the plan is to enhance them after we get the framework in. The cluster tests are noisy, so redirect all log output to a file like other tests do. Copying the conf around sucks but it's less work than messing with maven/sbt and having to clean up other projects. Note the test is only added for yarn-stable. The code compiles against yarn-alpha but there are two issues I ran into that I could not overcome: - an old netty dependency kept creeping into the classpath and causing akka to not work, when using sbt; the old netty was correctly suppressed under maven. - MiniYARNCluster kept failing to execute containers because it did not create the NM's local dir itself; this is apparently a known behavior, but I'm not sure how to work around it. None of those issues are present with the stable Yarn. Also, these tests are a little slow to run. Apparently Spark doesn't yet tag tests (so that these could be isolated in a "slow" batch), so this is something to keep in mind. Author: Marcelo Vanzin Closes #2257 from vanzin/yarn-tests and squashes the following commits: 6d5b84e [Marcelo Vanzin] Fix wrong system property being set. 8b0933d [Marcelo Vanzin] Merge branch 'master' into yarn-tests 5c2b56f [Marcelo Vanzin] Use custom log4j conf for Yarn containers. ec73f17 [Marcelo Vanzin] More review feedback. 67f5b02 [Marcelo Vanzin] Review feedback. f01517c [Marcelo Vanzin] Review feedback. 68fbbbf [Marcelo Vanzin] Use older constructor available in older Hadoop releases. d07ef9a [Marcelo Vanzin] Merge branch 'master' into yarn-tests add8416 [Marcelo Vanzin] [SPARK-2778] [yarn] Add yarn integration tests. --- pom.xml | 31 +++- .../spark/deploy/yarn/ApplicationMaster.scala | 10 +- .../apache/spark/deploy/yarn/ClientBase.scala | 2 +- .../deploy/yarn/ExecutorRunnableUtil.scala | 2 +- yarn/pom.xml | 3 +- yarn/stable/pom.xml | 9 + .../src/test/resources/log4j.properties | 28 ++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 154 ++++++++++++++++++ 8 files changed, 229 insertions(+), 10 deletions(-) create mode 100644 yarn/stable/src/test/resources/log4j.properties create mode 100644 yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala diff --git a/pom.xml b/pom.xml index 520aed3806937..f3de097b9cb32 100644 --- a/pom.xml +++ b/pom.xml @@ -712,6 +712,35 @@ + + org.apache.hadoop + hadoop-yarn-server-tests + ${yarn.version} + tests + test + + + asm + asm + + + org.ow2.asm + asm + + + org.jboss.netty + netty + + + javax.servlet + servlet-api + + + commons-logging + commons-logging + + + org.apache.hadoop hadoop-yarn-server-web-proxy @@ -1187,7 +1216,7 @@ org.apache.zookeeper zookeeper - 3.4.5-mapr-1406 + 3.4.5-mapr-1406 diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9050808157257..b51daeb437516 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -401,17 +401,17 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. status = FinalApplicationStatus.SUCCEEDED } catch { - case e: InvocationTargetException => { + case e: InvocationTargetException => e.getCause match { - case _: InterruptedException => { + case _: InterruptedException => // Reporter thread can interrupt to stop user class - } + + case e => throw e } - } } finally { logDebug("Finishing main") + finalStatus = status } - finalStatus = status } } userClassThread.setName("Driver") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 4870b0cb3ddaf..1cf19c198509c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -348,7 +348,7 @@ private[spark] trait ClientBase extends Logging { } // For log4j configuration to reference - javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) val userClass = if (args.userClass != null) { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index bbbf615510762..d7a7175d5e578 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -98,7 +98,7 @@ trait ExecutorRunnableUtil extends Logging { */ // For log4j configuration to reference - javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server", diff --git a/yarn/pom.xml b/yarn/pom.xml index 815a736c2e8fd..8a7035c85e9f1 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -140,7 +140,6 @@ ${basedir}/../.. - ${spark.classpath} @@ -148,7 +147,7 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - + ../common/src/main/resources diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index fd934b7726181..97eb0548e77c3 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -32,4 +32,13 @@ jar Spark Project YARN Stable API + + + org.apache.hadoop + hadoop-yarn-server-tests + tests + test + + + diff --git a/yarn/stable/src/test/resources/log4j.properties b/yarn/stable/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..26b73a1b39744 --- /dev/null +++ b/yarn/stable/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala new file mode 100644 index 0000000000000..857a4447dd738 --- /dev/null +++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} + +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster + +import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils + +class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + // log4j configuration for the Yarn containers, so that their output is collected + // by Yarn instead of trying to overwrite unit-tests.log. + private val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + private var tempDir: File = _ + private var fakeSparkJar: File = _ + private var oldConf: Map[String, String] = _ + + override def beforeAll() { + tempDir = Utils.createTempDir() + + val logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8) + + val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator + + sys.props("java.class.path") + + oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(new YarnConfiguration()) + yarnCluster.start() + yarnCluster.getConfig().foreach { e => + sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) + } + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath())) + sys.props += ("spark.executor.instances" -> "1") + sys.props += ("spark.driver.extraClassPath" -> childClasspath) + sys.props += ("spark.executor.extraClassPath" -> childClasspath) + + super.beforeAll() + } + + override def afterAll() { + yarnCluster.stop() + sys.props.retain { case (k, v) => !k.startsWith("spark.") } + sys.props ++= oldConf + super.afterAll() + } + + test("run Spark in yarn-client mode") { + var result = File.createTempFile("result", null, tempDir) + YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) + checkResult(result) + } + + test("run Spark in yarn-cluster mode") { + val main = YarnClusterDriver.getClass.getName().stripSuffix("$") + var result = File.createTempFile("result", null, tempDir) + + // The Client object will call System.exit() after the job is done, and we don't want + // that because it messes up the scalatest monitoring. So replicate some of what main() + // does here. + val args = Array("--class", main, + "--jar", "file:" + fakeSparkJar.getAbsolutePath(), + "--arg", "yarn-cluster", + "--arg", result.getAbsolutePath(), + "--num-executors", "1") + val sparkConf = new SparkConf() + val yarnConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val clientArgs = new ClientArguments(args, sparkConf) + new Client(clientArgs, yarnConf, sparkConf).run() + checkResult(result) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + private def checkResult(result: File) = { + var resultString = Files.toString(result, Charsets.UTF_8) + resultString should be ("success") + } + +} + +private object YarnClusterDriver extends Logging with Matchers { + + def main(args: Array[String]) = { + if (args.length != 2) { + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: YarnClusterDriver [master] [result file] + """.stripMargin) + System.exit(1) + } + + val sc = new SparkContext(new SparkConf().setMaster(args(0)) + .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) + val status = new File(args(1)) + var result = "failure" + try { + val data = sc.parallelize(1 to 4, 4).collect().toSet + data should be (Set(1, 2, 3, 4)) + result = "success" + } finally { + sc.stop() + Files.write(result, status, Charsets.UTF_8) + } + } + +} From c3f2a8588e19aab814ac5cdd86575bb5558d5e46 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 25 Sep 2014 23:20:17 +0530 Subject: [PATCH 083/315] SPARK-2932 [STREAMING] Move MasterFailureTest out of "main" source directory (HT @vanzin) Whatever the reason was for having this test class in `main`, if there is one, appear to be moot. This may have been a result of earlier streaming test reorganization. This simply puts `MasterFailureTest` back under `test/`, removes some redundant copied code, and touches up a few tiny inspection warnings along the way. Author: Sean Owen Closes #2399 from srowen/SPARK-2932 and squashes the following commits: 3909411 [Sean Owen] Move MasterFailureTest to src/test, and remove redundant TestOutputStream --- .../apache/spark/streaming/FailureSuite.scala | 1 - .../spark/streaming}/MasterFailureTest.scala | 43 ++++--------------- 2 files changed, 8 insertions(+), 36 deletions(-) rename streaming/src/{main/scala/org/apache/spark/streaming/util => test/scala/org/apache/spark/streaming}/MasterFailureTest.scala (91%) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 92e1b76d28301..40434b1f9b709 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming import org.apache.spark.Logging -import org.apache.spark.streaming.util.MasterFailureTest import org.apache.spark.util.Utils import java.io.File diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala similarity index 91% rename from streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala rename to streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 98e17ff92e205..c53c01706083a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -15,20 +15,18 @@ * limitations under the License. */ -package org.apache.spark.streaming.util +package org.apache.spark.streaming import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -import StreamingContext._ +import org.apache.spark.streaming.StreamingContext._ import scala.util.Random -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import java.io.{File, ObjectInputStream, IOException} +import java.io.{File, IOException} import java.nio.charset.Charset import java.util.UUID @@ -91,7 +89,7 @@ object MasterFailureTest extends Logging { // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... - val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) + val expectedOutput = (1L to numBatches).map(i => (1L to i).sum).map(j => ("a", j)) val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { @@ -218,7 +216,7 @@ object MasterFailureTest extends Logging { while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer - val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output + val outputBuffer = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output def output = outputBuffer.flatMap(x => x) // Start the thread to kill the streaming after some time @@ -239,7 +237,7 @@ object MasterFailureTest extends Logging { while (!killed && !isLastOutputGenerated && !isTimedOut) { Thread.sleep(100) timeRan = System.currentTimeMillis() - startTime - isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) + isLastOutputGenerated = (output.nonEmpty && output.last == lastExpectedOutput) isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) } } catch { @@ -313,31 +311,6 @@ object MasterFailureTest extends Logging { } } -/** - * This is a output stream just for testing. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. - */ -private[streaming] -class TestOutputStream[T: ClassTag]( - parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] - ) extends ForEachDStream[T]( - parent, - (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output += collected - } - ) { - - // This is to clear the output buffer every it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - output.clear() - } -} - - /** * Thread to kill streaming context after a random period of time. */ From 9b56e249e09d8da20f703b9381c5c3c8a1a1d4a9 Mon Sep 17 00:00:00 2001 From: epahomov Date: Thu, 25 Sep 2014 14:50:12 -0700 Subject: [PATCH 084/315] [SPARK-3690] Closing shuffle writers we swallow more important exception Author: epahomov Closes #2537 from epahomov/SPARK-3690 and squashes the following commits: a0b7de4 [epahomov] [SPARK-3690] Closing shuffle writers we swallow more important exception --- .../org/apache/spark/scheduler/ShuffleMapTask.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 381eff2147e95..a98ee118254a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -69,8 +69,13 @@ private[spark] class ShuffleMapTask( return writer.stop(success = true).get } catch { case e: Exception => - if (writer != null) { - writer.stop(success = false) + try { + if (writer != null) { + writer.stop(success = false) + } + } catch { + case e: Exception => + log.debug("Could not stop writer", e) } throw e } finally { From ff637c9380a6342fd0a4dde0710ec23856751dd4 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Thu, 25 Sep 2014 16:11:00 -0700 Subject: [PATCH 085/315] [SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add warnings to KMeans, GeneralizedLinearAlgorithm, and computeSVD when called with input data that is not cached. KMeans is implemented iteratively, and I believe that GeneralizedLinearAlgorithm’s current optimizers are iterative and its future optimizers are also likely to be iterative. RowMatrix’s computeSVD is iterative against an RDD when run in DistARPACK mode. ALS and DecisionTree are iterative as well, but they implement RDD caching internally so do not require a warning. I added a warning to GeneralizedLinearAlgorithm rather than inside its optimizers, where the iteration actually occurs, because internally GeneralizedLinearAlgorithm maps its input data to an uncached RDD before passing it to an optimizer. (In other words, the warning would be printed for every GeneralizedLinearAlgorithm run, regardless of whether its input is cached, if the warning were in GradientDescent or other optimizer.) I assume that use of an uncached RDD by GeneralizedLinearAlgorithm is intentional, and that the mapping there (adding label, intercepts and scaling) is a lightweight operation. Arguably a user calling an optimizer such as GradientDescent will be knowledgable enough to cache their data without needing a log warning, so lack of a warning in the optimizers may be ok. Some of the documentation examples making use of these iterative algorithms did not cache their training RDDs (while others did). I updated the examples to always cache. I also fixed some (unrelated) minor errors in the documentation examples. Author: Aaron Staple Closes #2347 from staple/SPARK-1484 and squashes the following commits: bd49701 [Aaron Staple] Address review comments. ab2d4a4 [Aaron Staple] Disable warnings on python code path. a7a0f99 [Aaron Staple] Change code comments per review comments. 7cca1dc [Aaron Staple] Change warning message text. c77e939 [Aaron Staple] [SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data. 3b6c511 [Aaron Staple] Minor doc example fixes. --- docs/mllib-clustering.md | 3 +- docs/mllib-linear-methods.md | 9 ++-- docs/mllib-optimization.md | 1 + .../mllib/api/python/PythonMLLibAPI.scala | 54 ++++++++++--------- .../spark/mllib/clustering/KMeans.scala | 22 ++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 11 ++++ .../GeneralizedLinearAlgorithm.scala | 21 ++++++++ 7 files changed, 91 insertions(+), 30 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dfd9cd572888c..d10bd63746629 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -52,7 +52,7 @@ import org.apache.spark.mllib.linalg.Vectors // Load and parse the data val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() // Cluster the data into two classes using KMeans val numClusters = 2 @@ -100,6 +100,7 @@ public class KMeansExample { } } ); + parsedData.cache(); // Cluster the data into two classes using KMeans int numClusters = 2; diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 9137f9dc1b692..d31bec3e1bd01 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -396,7 +396,7 @@ val data = sc.textFile("data/mllib/ridge-data/lpsa.data") val parsedData = data.map { line => val parts = line.split(',') LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -} +}.cache() // Building the model val numIterations = 100 @@ -455,6 +455,7 @@ public class LinearRegression { } } ); + parsedData.cache(); // Building the model int numIterations = 100; @@ -470,7 +471,7 @@ public class LinearRegression { } } ); - JavaRDD MSE = new JavaDoubleRDD(valuesAndPreds.map( + double MSE = new JavaDoubleRDD(valuesAndPreds.map( new Function, Object>() { public Object call(Tuple2 pair) { return Math.pow(pair._1() - pair._2(), 2.0); @@ -553,8 +554,8 @@ but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} -val trainingData = ssc.textFileStream('/training/data/dir').map(LabeledPoint.parse) -val testData = ssc.textFileStream('/testing/data/dir').map(LabeledPoint.parse) +val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache() +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) {% endhighlight %} diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 26ce5f3c501ff..45141c235be90 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -217,6 +217,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val numFeatures = data.take(1)(0).features.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9164c294ac7b8..e9f41758581e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -67,11 +67,13 @@ class PythonMLLibAPI extends Serializable { MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) private def trainRegressionModel( - trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, + learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] - val model = trainFunc(data.rdd, initialWeights) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. + learner.disableUncachedWarning() + val model = learner.run(data.rdd, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(SerDe.dumps(model.weights)) ret.add(model.intercept: java.lang.Double) @@ -106,8 +108,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - lrAlg.run(data, initialWeights), + lrAlg, data, initialWeightsBA) } @@ -122,15 +123,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val lassoAlg = new LassoWithSGD() + lassoAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - LassoWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + lassoAlg, data, initialWeightsBA) } @@ -145,15 +145,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - RidgeRegressionWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + ridgeAlg, data, initialWeightsBA) } @@ -186,8 +185,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - SVMAlg.run(data, initialWeights), + SVMAlg, data, initialWeightsBA) } @@ -220,8 +218,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - LogRegAlg.run(data, initialWeights), + LogRegAlg, data, initialWeightsBA) } @@ -249,7 +246,14 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): KMeansModel = { - KMeans.train(data.rdd, k, maxIterations, runs, initializationMode) + val kMeansAlg = new KMeans() + .setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. + .disableUncachedWarning() + return kMeansAlg.run(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fce8fe29f6e40..7443f232ec3e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -112,11 +113,26 @@ class KMeans private ( this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. */ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { + + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Compute squared norms and cache them. val norms = data.map(v => breezeNorm(v.toBreeze, 2.0)) norms.persist() @@ -125,6 +141,12 @@ class KMeans private ( } val model = runBreeze(breezeData) norms.unpersist() + + // Warn at the end of the run as well, for increased visibility. + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } model } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2e414a73be8e0..4174f45d231c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: @@ -231,6 +232,10 @@ class RowMatrix( val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => + if (rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter) } @@ -256,6 +261,12 @@ class RowMatrix( logWarning(s"Requested $k singular values but only found $sk nonzeros.") } + // Warn at the end of the run as well, for increased visibility. + if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk)) val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 20c1fdd2269ce..d0fe4179685ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -133,6 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. */ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -149,6 +159,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { throw new SparkException("Input validation failed.") @@ -223,6 +238,12 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] weights = scaler.transform(weights) } + // Warn at the end of the run as well, for increased visibility. + if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + createModel(weights, intercept) } } From 0dc868e787a3bc69c1b8e90d916a6dcea8dbcd6d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 25 Sep 2014 16:49:15 -0700 Subject: [PATCH 086/315] [SPARK-3584] sbin/slaves doesn't work when we use password authentication for SSH Author: Kousuke Saruta Closes #2444 from sarutak/slaves-scripts-modification and squashes the following commits: eff7394 [Kousuke Saruta] Improve the description about Cluster Launch Script in docs/spark-standalone.md 7858225 [Kousuke Saruta] Modified sbin/slaves to use the environment variable "SPARK_SSH_FOREGROUND" as a flag 53d7121 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into slaves-scripts-modification e570431 [Kousuke Saruta] Added a description for SPARK_SSH_FOREGROUND variable 7120a0c [Kousuke Saruta] Added a description about default host for sbin/slaves 1bba8a9 [Kousuke Saruta] Added SPARK_SSH_FOREGROUND flag to sbin/slaves 88e2f17 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into slaves-scripts-modification 297e75d [Kousuke Saruta] Modified sbin/slaves not to export HOSTLIST --- .gitignore | 1 + .rat-excludes | 1 + conf/{slaves => slaves.template} | 0 docs/spark-standalone.md | 7 ++++++- sbin/slaves.sh | 31 ++++++++++++++++++++++--------- 5 files changed, 30 insertions(+), 10 deletions(-) rename conf/{slaves => slaves.template} (100%) diff --git a/.gitignore b/.gitignore index 7779980b74a22..34939e3a97aaa 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ conf/*.cmd conf/*.properties conf/*.conf conf/*.xml +conf/slaves docs/_site docs/api target/ diff --git a/.rat-excludes b/.rat-excludes index 9fc99d7fca35d..b14ad53720f32 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -19,6 +19,7 @@ log4j.properties log4j.properties.template metrics.properties.template slaves +slaves.template spark-env.sh spark-env.cmd spark-env.sh.template diff --git a/conf/slaves b/conf/slaves.template similarity index 100% rename from conf/slaves rename to conf/slaves.template diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 29b5491861bf3..58103fab20819 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -62,7 +62,12 @@ Finally, the following configuration options can be passed to the master and wor # Cluster Launch Scripts -To launch a Spark standalone cluster with the launch scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file. +To launch a Spark standalone cluster with the launch scripts, you should create a file called conf/slaves in your Spark directory, +which must contain the hostnames of all the machines where you intend to start Spark workers, one per line. +If conf/slaves does not exist, the launch scripts defaults to a single machine (localhost), which is useful for testing. +Note, the master machine accesses each of the worker machines via ssh. By default, ssh is run in parallel and requires password-less (using a private key) access to be setup. +If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. + Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: diff --git a/sbin/slaves.sh b/sbin/slaves.sh index 1d4dc5edf9858..cdad47ee2e594 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -44,7 +44,9 @@ sbin="`cd "$sbin"; pwd`" # If the slaves file is specified in the command line, # then it takes precedence over the definition in # spark-env.sh. Save it here. -HOSTLIST="$SPARK_SLAVES" +if [ -f "$SPARK_SLAVES" ]; then + HOSTLIST=`cat "$SPARK_SLAVES"` +fi # Check if --config is passed as an argument. It is an optional parameter. # Exit if the argument is not a directory. @@ -67,23 +69,34 @@ fi if [ "$HOSTLIST" = "" ]; then if [ "$SPARK_SLAVES" = "" ]; then - export HOSTLIST="${SPARK_CONF_DIR}/slaves" + if [ -f "${SPARK_CONF_DIR}/slaves" ]; then + HOSTLIST=`cat "${SPARK_CONF_DIR}/slaves"` + else + HOSTLIST=localhost + fi else - export HOSTLIST="${SPARK_SLAVES}" + HOSTLIST=`cat "${SPARK_SLAVES}"` fi fi + + # By default disable strict host key checking if [ "$SPARK_SSH_OPTS" = "" ]; then SPARK_SSH_OPTS="-o StrictHostKeyChecking=no" fi -for slave in `cat "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do - ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ - 2>&1 | sed "s/^/$slave: /" & - if [ "$SPARK_SLAVE_SLEEP" != "" ]; then - sleep $SPARK_SLAVE_SLEEP - fi +for slave in `echo "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do + if [ -n "${SPARK_SSH_FOREGROUND}" ]; then + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$slave: /" + else + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$slave: /" & + fi + if [ "$SPARK_SLAVE_SLEEP" != "" ]; then + sleep $SPARK_SLAVE_SLEEP + fi done wait From 86bce764983f2b14e1bd87fc3f4f938f7a217e1b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Sep 2014 18:24:01 -0700 Subject: [PATCH 087/315] SPARK-2634: Change MapOutputTrackerWorker.mapStatuses to ConcurrentHashMap MapOutputTrackerWorker.mapStatuses is used concurrently, it should be thread-safe. This bug has already been fixed in #1328. Nevertheless, considering #1328 won't be merged soon, I send this trivial fix and hope this issue can be solved soon. Author: zsxwing Closes #1541 from zsxwing/SPARK-2634 and squashes the following commits: d450053 [zsxwing] SPARK-2634: Change MapOutputTrackerWorker.mapStatuses to ConcurrentHashMap --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 51705c895a55c..f92189b707fb5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,10 +18,12 @@ package org.apache.spark import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await +import scala.collection.JavaConversions._ import akka.actor._ import akka.pattern.ask @@ -84,6 +86,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks. * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the * master's corresponding HashMap. + * + * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a + * thread-safe map. */ protected val mapStatuses: Map[Int, Array[MapStatus]] @@ -339,7 +344,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) * MapOutputTrackerMaster. */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses = new HashMap[Int, Array[MapStatus]] + protected val mapStatuses: Map[Int, Array[MapStatus]] = + new ConcurrentHashMap[Int, Array[MapStatus]] } private[spark] object MapOutputTracker { From b235e013638685758885842dc3268e9800af3678 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Thu, 25 Sep 2014 22:56:43 -0700 Subject: [PATCH 088/315] [SPARK-3686][STREAMING] Wait for sink to commit the channel before check... ...ing for the channel size. Author: Hari Shreedharan Closes #2531 from harishreedharan/sparksinksuite-fix and squashes the following commits: 30393c1 [Hari Shreedharan] Use more deterministic method to figure out when batches come in. 6ce9d8b [Hari Shreedharan] [SPARK-3686][STREAMING] Wait for sink to commit the channel before checking for the channel size. --- .../flume/sink/SparkAvroCallbackHandler.scala | 14 +++++++++++- .../streaming/flume/sink/SparkSink.scala | 10 +++++++++ .../flume/sink/TransactionProcessor.scala | 12 ++++++++++ .../streaming/flume/sink/SparkSinkSuite.scala | 22 +++++++++++-------- 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index e77cf7bfa54d0..3c656a381bd9b 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.flume.sink -import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.{CountDownLatch, ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConversions._ @@ -58,8 +58,12 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha private val seqBase = RandomStringUtils.randomAlphanumeric(8) private val seqCounter = new AtomicLong(0) + @volatile private var stopped = false + @volatile private var isTest = false + private var testLatch: CountDownLatch = null + /** * Returns a bunch of events to Spark over Avro RPC. * @param n Maximum number of events to return in a batch @@ -90,6 +94,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha val processor = new TransactionProcessor( channel, seq, n, transactionTimeout, backOffInterval, this) sequenceNumberToProcessor.put(seq, processor) + if (isTest) { + processor.countDownWhenBatchAcked(testLatch) + } Some(processor) } else { None @@ -141,6 +148,11 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha } } + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } + /** * Shuts down the executor used to process transactions. */ diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 98ae7d783aec8..14dffb15fef98 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -138,6 +138,16 @@ class SparkSink extends AbstractSink with Logging with Configurable { throw new RuntimeException("Server was not started!") ) } + + /** + * Pass in a [[CountDownLatch]] for testing purposes. This batch is counted down when each + * batch is received. The test can simply call await on this latch till the expected number of + * batches are received. + * @param latch + */ + private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) { + handler.foreach(_.countDownWhenBatchAcked(latch)) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 13f3aa94be414..ea45b14294df9 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -62,6 +62,10 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, @volatile private var stopped = false + @volatile private var isTest = false + + private var testLatch: CountDownLatch = null + // The transaction that this processor would handle var txOpt: Option[Transaction] = None @@ -182,6 +186,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, rollbackAndClose(tx, close = false) // tx will be closed later anyway } finally { tx.close() + if (isTest) { + testLatch.countDown() + } } } else { logWarning("Spark could not commit transaction, NACK received. Rolling back transaction.") @@ -237,4 +244,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, processAckOrNack() null } + + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } } diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 75a6668c6210b..a2b2cc6149d95 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -38,7 +38,7 @@ class SparkSinkSuite extends FunSuite { val channelCapacity = 5000 test("Success with ack") { - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink() channel.start() sink.start() @@ -51,6 +51,7 @@ class SparkSinkSuite extends FunSuite { val events = client.getEventBatch(1000) client.ack(events.getSequenceNumber) assert(events.getEvents.size() === 1000) + latch.await(1, TimeUnit.SECONDS) assertChannelIsEmpty(channel) sink.stop() channel.stop() @@ -58,7 +59,7 @@ class SparkSinkSuite extends FunSuite { } test("Failure with nack") { - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink() channel.start() sink.start() putEvents(channel, eventsPerBatch) @@ -70,6 +71,7 @@ class SparkSinkSuite extends FunSuite { val events = client.getEventBatch(1000) assert(events.getEvents.size() === 1000) client.nack(events.getSequenceNumber) + latch.await(1, TimeUnit.SECONDS) assert(availableChannelSlots(channel) === 4000) sink.stop() channel.stop() @@ -77,7 +79,7 @@ class SparkSinkSuite extends FunSuite { } test("Failure with timeout") { - val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig + val (channel, sink, latch) = initializeChannelAndSink(Map(SparkSinkConfig .CONF_TRANSACTION_TIMEOUT -> 1.toString)) channel.start() sink.start() @@ -88,7 +90,7 @@ class SparkSinkSuite extends FunSuite { val (transceiver, client) = getTransceiverAndClient(address, 1)(0) val events = client.getEventBatch(1000) assert(events.getEvents.size() === 1000) - Thread.sleep(1000) + latch.await(1, TimeUnit.SECONDS) assert(availableChannelSlots(channel) === 4000) sink.stop() channel.stop() @@ -106,7 +108,7 @@ class SparkSinkSuite extends FunSuite { def testMultipleConsumers(failSome: Boolean): Unit = { implicit val executorContext = ExecutionContext .fromExecutorService(Executors.newFixedThreadPool(5)) - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink(Map.empty, 5) channel.start() sink.start() (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) @@ -136,7 +138,7 @@ class SparkSinkSuite extends FunSuite { } }) batchCounter.await() - TimeUnit.SECONDS.sleep(1) // Allow the sink to commit the transactions. + latch.await(1, TimeUnit.SECONDS) executorContext.shutdown() if(failSome) { assert(availableChannelSlots(channel) === 3000) @@ -148,8 +150,8 @@ class SparkSinkSuite extends FunSuite { transceiversAndClients.foreach(x => x._1.close()) } - private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel, - SparkSink) = { + private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty, + batchCounter: Int = 1): (MemoryChannel, SparkSink, CountDownLatch) = { val channel = new MemoryChannel() val channelContext = new Context() @@ -165,7 +167,9 @@ class SparkSinkSuite extends FunSuite { sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) sink.configure(sinkContext) sink.setChannel(channel) - (channel, sink) + val latch = new CountDownLatch(batchCounter) + sink.countdownWhenBatchReceived(latch) + (channel, sink, latch) } private def putEvents(ch: MemoryChannel, count: Int): Unit = { From 1aa549ba9839565274a12c52fa1075b424f138a6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 26 Sep 2014 09:27:42 -0700 Subject: [PATCH 089/315] [SPARK-3478] [PySpark] Profile the Python tasks This patch add profiling support for PySpark, it will show the profiling results before the driver exits, here is one example: ``` ============================================================ Profile of RDD ============================================================ 5146507 function calls (5146487 primitive calls) in 71.094 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 5144576 68.331 0.000 68.331 0.000 statcounter.py:44(merge) 20 2.735 0.137 71.071 3.554 statcounter.py:33(__init__) 20 0.017 0.001 0.017 0.001 {cPickle.dumps} 1024 0.003 0.000 0.003 0.000 t.py:16() 20 0.001 0.000 0.001 0.000 {reduce} 21 0.001 0.000 0.001 0.000 {cPickle.loads} 20 0.001 0.000 0.001 0.000 copy_reg.py:95(_slotnames) 41 0.001 0.000 0.001 0.000 serializers.py:461(read_int) 40 0.001 0.000 0.002 0.000 serializers.py:179(_batched) 62 0.000 0.000 0.000 0.000 {method 'read' of 'file' objects} 20 0.000 0.000 71.072 3.554 rdd.py:863() 20 0.000 0.000 0.001 0.000 serializers.py:198(load_stream) 40/20 0.000 0.000 71.072 3.554 rdd.py:2093(pipeline_func) 41 0.000 0.000 0.002 0.000 serializers.py:130(load_stream) 40 0.000 0.000 71.072 1.777 rdd.py:304(func) 20 0.000 0.000 71.094 3.555 worker.py:82(process) ``` Also, use can show profile result manually by `sc.show_profiles()` or dump it into disk by `sc.dump_profiles(path)`, such as ```python >>> sc._conf.set("spark.python.profile", "true") >>> rdd = sc.parallelize(range(100)).map(str) >>> rdd.count() 100 >>> sc.show_profiles() ============================================================ Profile of RDD ============================================================ 284 function calls (276 primitive calls) in 0.001 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 4 0.000 0.000 0.000 0.000 serializers.py:198(load_stream) 4 0.000 0.000 0.000 0.000 {reduce} 12/4 0.000 0.000 0.001 0.000 rdd.py:2092(pipeline_func) 4 0.000 0.000 0.000 0.000 {cPickle.loads} 4 0.000 0.000 0.000 0.000 {cPickle.dumps} 104 0.000 0.000 0.000 0.000 rdd.py:852() 8 0.000 0.000 0.000 0.000 serializers.py:461(read_int) 12 0.000 0.000 0.000 0.000 rdd.py:303(func) ``` The profiling is disabled by default, can be enabled by "spark.python.profile=true". Also, users can dump the results into disks automatically for future analysis, by "spark.python.profile.dump=path_to_dump" Author: Davies Liu Closes #2351 from davies/profiler and squashes the following commits: 7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles() 2b0daf2 [Davies Liu] fix docs 7a56c24 [Davies Liu] bugfix cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 09d02c3 [Davies Liu] Merge branch 'master' into profiler c23865c [Davies Liu] Merge branch 'master' into profiler 15d6f18 [Davies Liu] add docs for two configs dadee1a [Davies Liu] add docs string and clear profiles after show or dump 4f8309d [Davies Liu] address comment, add tests 0a5b6eb [Davies Liu] fix Python UDF 4b20494 [Davies Liu] add profile for python --- docs/configuration.md | 19 +++++++++++++++++ python/pyspark/accumulators.py | 15 +++++++++++++ python/pyspark/context.py | 39 +++++++++++++++++++++++++++++++++- python/pyspark/rdd.py | 10 +++++++-- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 30 ++++++++++++++++++++++++++ python/pyspark/worker.py | 19 ++++++++++++++--- 7 files changed, 127 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index a6dd7245e1552..791b6f2aa3261 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + + + + + + + + diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..b8cdbbe3cf2b6 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,6 +215,21 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8e7b00469e246..abeda19b77d8b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,6 +20,7 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile +import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -30,7 +31,6 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel -from pyspark import rdd from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._temp_dir = \ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() + # profiling stats collected for each PythonRDD + self._profile_stats = [] + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) + def _add_profile(self, id, profileAcc): + if not self._profile_stats: + dump_path = self._conf.get("spark.python.profile.dump") + if dump_path: + atexit.register(self.dump_profiles, dump_path) + else: + atexit.register(self.show_profiles) + + self._profile_stats.append([id, profileAcc, False]) + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, acc, showed) in enumerate(self._profile_stats): + stats = acc.value + if not showed and stats: + print "=" * 60 + print "Profile of RDD" % id + print "=" * 60 + stats.sort_stats("tottime", "cumtime").print_stats() + # mark it as showed + self._profile_stats[i][2] = True + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` + """ + if not os.path.exists(path): + os.makedirs(path) + for id, acc, _ in self._profile_stats: + stats = acc.value + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + self._profile_stats = [] + def _test(): import atexit diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 680140d72d03c..8ed89e2f9769f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,7 +15,6 @@ # limitations under the License. # -from base64 import standard_b64encode as b64enc import copy from collections import defaultdict from itertools import chain, ifilter, imap @@ -32,6 +31,7 @@ from random import Random from math import sqrt, log, isinf, isnan +from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2080,7 +2080,9 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - command = (self.func, self._prev_jrdd_deserializer, + enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" + profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None + command = (self.func, profileStats, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2102,6 +2104,10 @@ def _jrdd(self): self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() + + if enable_profile: + self._id = self._jrdd_val.id() + self.ctx._add_profile(self._id, profileStats) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 653195ea438cf..ee5bda8bb43d5 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -974,7 +974,7 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, + command = (func, None, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index d1bb2033b7a16..e6002afa9c70d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -632,6 +632,36 @@ def test_distinct(self): self.assertEquals(result.count(), 3) +class TestProfiler(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + + def test_profiler(self): + + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + profiles = self.sc._profile_stats + self.assertEqual(1, len(profiles)) + id, acc, _ = profiles[0] + stats = acc.value + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + self.sc.show_profiles() + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + class TestSQL(PySparkTestCase): def setUp(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c1f6e3e4a1f40..8257dddfee1c3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,8 @@ import time import socket import traceback +import cProfile +import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,10 +92,21 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, deserializer, serializer) = command + (func, stats, deserializer, serializer) = command init_time = time.time() - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + + def process(): + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) + + if stats: + p = cProfile.Profile() + p.runcall(process) + st = pstats.Stats(p) + st.stream = None # make it picklable + stats.add(st.strip_dirs()) + else: + process() except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) From d16e161d744b27291fd2ee7e3578917ee14d83f9 Mon Sep 17 00:00:00 2001 From: aniketbhatnagar Date: Fri, 26 Sep 2014 09:47:58 -0700 Subject: [PATCH 090/315] SPARK-3639 | Removed settings master in examples This patch removes setting of master as local in Kinesis examples so that users can set it using submit-job. Author: aniketbhatnagar Closes #2536 from aniketbhatnagar/Kinesis-Examples-Master-Unset and squashes the following commits: c9723ac [aniketbhatnagar] Merge remote-tracking branch 'origin/Kinesis-Examples-Master-Unset' into Kinesis-Examples-Master-Unset fec8ead [aniketbhatnagar] SPARK-3639 | Removed settings master in examples 31cdc59 [aniketbhatnagar] SPARK-3639 | Removed settings master in examples --- .../examples/streaming/JavaKinesisWordCountASL.java | 9 ++++----- .../examples/streaming/KinesisWordCountASL.scala | 13 +++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index aa917d0575c4c..b0bff27a61c19 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -71,6 +71,9 @@ * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com * + * Note that number of workers/threads should be 1 more than the number of receivers. + * This leaves one thread available for actually processing the data. + * * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data * onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. @@ -114,12 +117,8 @@ public static void main(String[] args) { /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ int numStreams = numShards; - /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ - int numSparkThreads = numStreams + 1; - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( - "local[" + numSparkThreads + "]"); + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); /* Kinesis checkpoint interval. Same as batchInterval for this example. */ Duration checkpointInterval = batchInterval; diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index fffd90de08240..32da0858d1a1d 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -65,6 +65,10 @@ import org.apache.log4j.Level * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com * + * + * Note that number of workers/threads should be 1 more than the number of receivers. + * This leaves one thread available for actually processing the data. + * * There is a companion helper class below called KinesisWordCountProducerASL which puts * dummy data onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. @@ -97,17 +101,10 @@ private object KinesisWordCountASL extends Logging { /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ val numStreams = numShards - /* - * numSparkThreads should be 1 more thread than the number of receivers. - * This leaves one thread available for actually processing the data. - */ - val numSparkThreads = numStreams + 1 - /* Setup the and SparkConfig and StreamingContext */ /* Spark Streaming batch interval */ - val batchInterval = Milliseconds(2000) + val batchInterval = Milliseconds(2000) val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - .setMaster(s"local[$numSparkThreads]") val ssc = new StreamingContext(sparkConfig, batchInterval) /* Kinesis checkpoint interval. Same as batchInterval for this example. */ From ec9df6a765701fa41390083df12e1dc1fee50662 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Fri, 26 Sep 2014 09:58:47 -0700 Subject: [PATCH 091/315] [SPARK-3614][MLLIB] Add minimumOccurence filtering to IDF This PR for [SPARK-3614](https://issues.apache.org/jira/browse/SPARK-3614) adds functionality for filtering out terms which do not appear in at least a minimum number of documents. This is implemented using a minimumOccurence parameter (default 0). When terms' document frequencies are less than minimumOccurence, their IDFs are set to 0, just like when the DF is 0. As a result, the TF-IDFs for the terms are found to be 0, as if the terms were not present in the documents. This PR makes the following changes: * Add a minimumOccurence parameter to the IDF and DocumentFrequencyAggregator classes. * Create a parameter-less constructor for IDF with a default minimumOccurence value of 0 to remain backwards-compatibility with the original IDF API. * Sets the IDFs to 0 for terms which DFs are less than minimumOccurence * Add tests to the Spark IDFSuite and Java JavaTfIdfSuite test suites * Updated the MLLib Feature Extraction programming guide to describe the new feature Author: RJ Nowling Closes #2494 from rnowling/spark-3614-idf-filter and squashes the following commits: 0aa3c63 [RJ Nowling] Fix identation e6523a8 [RJ Nowling] Remove unnecessary toDouble's from IDFSuite bfa82ec [RJ Nowling] Add space after if 30d20b3 [RJ Nowling] Add spaces around equals signs 9013447 [RJ Nowling] Add space before division operator 79978fc [RJ Nowling] Remove unnecessary semi-colon 40fd70c [RJ Nowling] Change minimumOccurence to minDocFreq in code and docs 47850ab [RJ Nowling] Changed minimumOccurence to Int from Long 9fb4093 [RJ Nowling] Remove unnecessary lines from IDF class docs 1fc09d8 [RJ Nowling] Add backwards-compatible constructor to DocumentFrequencyAggregator 1801fd2 [RJ Nowling] Fix style errors in IDF.scala 6897252 [RJ Nowling] Preface minimumOccurence members with val to make them final and immutable a200bab [RJ Nowling] Remove unnecessary else statement 4b974f5 [RJ Nowling] Remove accidentally-added import from testing c0cc643 [RJ Nowling] Add minimumOccurence filtering to IDF --- docs/mllib-feature-extraction.md | 15 ++++++++ .../org/apache/spark/mllib/feature/IDF.scala | 37 +++++++++++++++++-- .../spark/mllib/feature/JavaTfIdfSuite.java | 20 ++++++++++ .../apache/spark/mllib/feature/IDFSuite.scala | 36 +++++++++++++++++- 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 41a27f6208d1b..1511ae6dda4ed 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -82,6 +82,21 @@ tf.cache() val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} + +MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature +can be used by passing the `minDocFreq` value to the IDF constructor. + +{% highlight scala %} +import org.apache.spark.mllib.feature.IDF + +// ... continue from the previous example +tf.cache() +val idf = new IDF(minDocFreq = 2).fit(tf) +val tfidf: RDD[Vector] = idf.transform(tf) +{% endhighlight %} + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index d40d5553c1d21..720bb70b08dbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -30,9 +30,18 @@ import org.apache.spark.rdd.RDD * Inverse document frequency (IDF). * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total * number of documents and `d(t)` is the number of documents that contain term `t`. + * + * This implementation supports filtering out terms which do not appear in a minimum number + * of documents (controlled by the variable `minDocFreq`). For terms that are not in + * at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + * + * @param minDocFreq minimum of documents in which a term + * should appear for filtering */ @Experimental -class IDF { +class IDF(val minDocFreq: Int) { + + def this() = this(0) // TODO: Allow different IDF formulations. @@ -41,7 +50,8 @@ class IDF { * @param dataset an RDD of term frequency vectors */ def fit(dataset: RDD[Vector]): IDFModel = { - val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( + minDocFreq = minDocFreq))( seqOp = (df, v) => df.add(v), combOp = (df1, df2) => df1.merge(df2) ).idf() @@ -60,13 +70,16 @@ class IDF { private object IDF { /** Document frequency aggregator. */ - class DocumentFrequencyAggregator extends Serializable { + class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable { /** number of documents */ private var m = 0L /** document frequency vector */ private var df: BDV[Long] = _ + + def this() = this(0) + /** Adds a new document. */ def add(doc: Vector): this.type = { if (isEmpty) { @@ -123,7 +136,18 @@ private object IDF { val inv = new Array[Double](n) var j = 0 while (j < n) { - inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + /* + * If the term is not present in the minimum + * number of documents, set IDF to 0. This + * will cause multiplication in IDFModel to + * set TF-IDF to 0. + * + * Since arrays are initialized to 0 by default, + * we just omit changing those entries. + */ + if(df(j) >= minDocFreq) { + inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) + } j += 1 } Vectors.dense(inv) @@ -140,6 +164,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. + * + * If `minDocFreq` was set for the IDF calculation, + * the terms which occur in fewer than `minDocFreq` + * documents will have an entry of 0. + * * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index e8d99f4ae43ae..064263e02cd11 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -63,4 +63,24 @@ public void tfIdf() { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } + + @Test + public void tfIdfMinimumDocumentFrequency() { + // The tests are to check Java compatibility. + HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(2); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 53d9c0c640b98..43974f84e3ca8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -38,7 +38,7 @@ class IDFSuite extends FunSuite with LocalSparkContext { val idf = new IDF val model = idf.fit(termFrequencies) val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => - math.log((m.toDouble + 1.0) / (x + 1.0)) + math.log((m + 1.0) / (x + 1.0)) }) assert(model.idf ~== expected absTol 1e-12) val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() @@ -54,4 +54,38 @@ class IDFSuite extends FunSuite with LocalSparkContext { assert(tfidf2.indices === Array(1)) assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) } + + test("idf minimum document frequency filtering") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF(minDocFreq = 1) + val model = idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) { + math.log((m + 1.0) / (x + 1.0)) + } else { + 0 + } + }) + assert(model.idf ~== expected absTol 1e-12) + val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } + } From 30461c6ac3dcfb05dc1891494ec161601c0fb59f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 26 Sep 2014 11:26:53 -0700 Subject: [PATCH 092/315] [SPARK-3695]shuffle fetch fail output should output detailed host and port in error message Author: Daoyuan Wang Closes #2539 from adrian-wang/fetchfail and squashes the following commits: 6c1b1e0 [Daoyuan Wang] shuffle fetch fail output --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d868758a7f549..71b276b5f18e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -121,7 +121,7 @@ final class ShuffleBlockFetcherIterator( } override def onBlockFetchFailure(e: Throwable): Unit = { - logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) // Note that there is a chance that some blocks have been fetched successfully, but we // still add them to the failed queue. This is fine because when the caller see a // FetchFailedException, it is going to fail the entire task anyway. From 8da10bf14660f1d5b1dab692cb56b9832ab10d40 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 26 Sep 2014 11:50:48 -0700 Subject: [PATCH 093/315] [SPARK-3476] Remove outdated memory checks in Yarn See description in [JIRA](https://issues.apache.org/jira/browse/SPARK-3476). Author: Andrew Or Closes #2528 from andrewor14/yarn-memory-checks and squashes the following commits: c5400cd [Andrew Or] Simplify checks e30ffac [Andrew Or] Remove outdated memory checks --- .../apache/spark/deploy/yarn/ClientArguments.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 201b742736c6e..26dbd6237c6b8 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -69,16 +69,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. */ private def validateArgs(): Unit = { - // TODO: memory checks are outdated (SPARK-3476) - Map[Boolean, String]( - (numExecutors <= 0) -> "You must specify at least 1 executor!", - (amMemory <= amMemoryOverhead) -> s"AM memory must be > $amMemoryOverhead MB", - (executorMemory <= executorMemoryOverhead) -> - s"Executor memory must be > $executorMemoryOverhead MB" - ).foreach { case (errorCondition, errorMessage) => - if (errorCondition) { - throw new IllegalArgumentException(errorMessage + "\n" + getUsageMessage()) - } + if (numExecutors <= 0) { + throw new IllegalArgumentException( + "You must specify at least 1 executor!\n" + getUsageMessage()) } } From 0ec2d2e8f0c0dc61a7ed6377898846661d2424cd Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 26 Sep 2014 12:04:37 -0700 Subject: [PATCH 094/315] [SPARK-3531][SQL]select null from table would throw a MatchError Author: Daoyuan Wang Closes #2396 from adrian-wang/selectnull and squashes the following commits: 2458229 [Daoyuan Wang] rebase solution --- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 1 + .../select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 | 1 + .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 3 +++ 3 files changed, 5 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b4399e852c7b..9a0b9b46ac4ee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -244,6 +244,7 @@ object HiveMetastoreTypes extends RegexParsers { case BooleanType => "boolean" case DecimalType => "decimal" case TimestampType => "timestamp" + case NullType => "void" } } diff --git a/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 b/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 426f5fcee6157..2f876cafaf218 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -330,6 +330,9 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + createQueryTest("select null from table", + "SELECT null FROM src LIMIT 1") + test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => i } From 7364fa5a176da69e425bca0e3e137ee73275c78c Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 26 Sep 2014 12:06:01 -0700 Subject: [PATCH 095/315] [SPARK-3393] [SQL] Align the log4j configuration for Spark & SparkSQLCLI User may be confused for the HQL logging & configurations, we'd better provide a default templates. Both files are copied from Hive. Author: Cheng Hao Closes #2263 from chenghao-intel/hive_template and squashes the following commits: 53bffa9 [Cheng Hao] Remove the hive-log4j.properties initialization --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index b092f42372171..7ba4564602ecd 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -73,18 +73,6 @@ private[hive] object SparkSQLCLIDriver { System.exit(1) } - // NOTE: It is critical to do this here so that log4j is reinitialized - // before any of the other core hive classes are loaded - var logInitFailed = false - var logInitDetailMessage: String = null - try { - logInitDetailMessage = LogUtils.initHiveLog4j() - } catch { - case e: LogInitializationException => - logInitFailed = true - logInitDetailMessage = e.getMessage - } - val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) sessionState.in = System.in @@ -100,11 +88,6 @@ private[hive] object SparkSQLCLIDriver { System.exit(2) } - if (!sessionState.getIsSilent) { - if (logInitFailed) System.err.println(logInitDetailMessage) - else SessionState.getConsole.printInfo(logInitDetailMessage) - } - // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => From f872e4fb80b8429800daa9c44c0cac620c1ff303 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Sep 2014 14:47:14 -0700 Subject: [PATCH 096/315] Revert "[SPARK-3478] [PySpark] Profile the Python tasks" This reverts commit 1aa549ba9839565274a12c52fa1075b424f138a6. --- docs/configuration.md | 19 ----------------- python/pyspark/accumulators.py | 15 ------------- python/pyspark/context.py | 39 +--------------------------------- python/pyspark/rdd.py | 10 ++------- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 30 -------------------------- python/pyspark/worker.py | 19 +++-------------- 7 files changed, 7 insertions(+), 127 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 791b6f2aa3261..a6dd7245e1552 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,25 +206,6 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. - - - - - - - - - diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b6..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index abeda19b77d8b..8e7b00469e246 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -31,6 +30,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel +from pyspark import rdd from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -192,9 +192,6 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._temp_dir = \ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() - # profiling stats collected for each PythonRDD - self._profile_stats = [] - def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -795,40 +792,6 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - - def show_profiles(self): - """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 - stats.sort_stats("tottime", "cumtime").print_stats() - # mark it as showed - self._profile_stats[i][2] = True - - def dump_profiles(self, path): - """ Dump the profile stats into directory `path` - """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] - def _test(): import atexit diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ed89e2f9769f..680140d72d03c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,6 +15,7 @@ # limitations under the License. # +from base64 import standard_b64encode as b64enc import copy from collections import defaultdict from itertools import chain, ifilter, imap @@ -31,7 +32,6 @@ from random import Random from math import sqrt, log, isinf, isnan -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2080,9 +2080,7 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2104,10 +2102,6 @@ def _jrdd(self): self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - - if enable_profile: - self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ee5bda8bb43d5..653195ea438cf 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -974,7 +974,7 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, + command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e6002afa9c70d..d1bb2033b7a16 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -632,36 +632,6 @@ def test_distinct(self): self.assertEquals(result.count(), 3) -class TestProfiler(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) - - def test_profiler(self): - - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - self.sc.show_profiles() - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - class TestSQL(PySparkTestCase): def setUp(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8257dddfee1c3..c1f6e3e4a1f40 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -92,21 +90,10 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, deserializer, serializer) = command init_time = time.time() - - def process(): - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) - - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) - else: - process() + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) From 5e34855cf04145cc3b7bae996c2a6e668f144a11 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 26 Sep 2014 21:29:54 -0700 Subject: [PATCH 097/315] [SPARK-3543] Write TaskContext in Java and expose it through a static accessor. Author: Prashant Sharma Author: Shashank Sharma Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits: 8ae414c [Shashank Sharma] CR ee8bd00 [Prashant Sharma] Added internal API in docs comments. ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated. a7d5e23 [Prashant Sharma] Added doc comments. edf945e [Prashant Sharma] Code review git add -A f716fd1 [Prashant Sharma] introduced thread local for getting the task context. 333c7d6 [Prashant Sharma] Translated Task context from scala to java. --- .../java/org/apache/spark/TaskContext.java | 274 ++++++++++++++++++ .../scala/org/apache/spark/TaskContext.scala | 126 -------- .../main/scala/org/apache/spark/rdd/RDD.scala | 1 + .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../org/apache/spark/scheduler/Task.scala | 6 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 2 +- 7 files changed, 284 insertions(+), 131 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/TaskContext.java delete mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java new file mode 100644 index 0000000000000..09b8ce02bd3d8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import scala.Function0; +import scala.Function1; +import scala.Unit; +import scala.collection.JavaConversions; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskCompletionListenerException; + +/** +* :: DeveloperApi :: +* Contextual information about a task which can be read or mutated during execution. +*/ +@DeveloperApi +public class TaskContext implements Serializable { + + private int stageId; + private int partitionId; + private long attemptId; + private boolean runningLocally; + private TaskMetrics taskMetrics; + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + TaskMetrics taskMetrics) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = taskMetrics; + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, + Boolean runningLocally) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = false; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + private static ThreadLocal taskContext = + new ThreadLocal(); + + /** + * :: Internal API :: + * This is spark internal API, not intended to be called from user programs. + */ + public static void setTaskContext(TaskContext tc) { + taskContext.set(tc); + } + + public static TaskContext get() { + return taskContext.get(); + } + + /** + * :: Internal API :: + */ + public static void remove() { + taskContext.remove(); + } + + // List of callback functions to execute when the task completes. + private transient List onCompleteCallbacks = + new ArrayList(); + + // Whether the corresponding task has been killed. + private volatile Boolean interrupted = false; + + // Whether the task has completed. + private volatile Boolean completed = false; + + /** + * Checks whether the task has completed. + */ + public Boolean isCompleted() { + return completed; + } + + /** + * Checks whether the task has been killed. + */ + public Boolean isInterrupted() { + return interrupted; + } + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + *

    + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { + onCompleteCallbacks.add(listener); + return this; + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + *

    + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(final Function1 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(context); + } + }); + return this; + } + + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * Deprecated: use addTaskCompletionListener + * + * @param f Callback function. + */ + @Deprecated + public void addOnCompleteCallback(final Function0 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(); + } + }); + } + + /** + * ::Internal API:: + * Marks the task as completed and triggers the listeners. + */ + public void markTaskCompleted() throws TaskCompletionListenerException { + completed = true; + List errorMsgs = new ArrayList(2); + // Process complete callbacks in the reverse order of registration + List revlist = + new ArrayList(onCompleteCallbacks); + Collections.reverse(revlist); + for (TaskCompletionListener tcl: revlist) { + try { + tcl.onTaskCompletion(this); + } catch (Throwable e) { + errorMsgs.add(e.getMessage()); + } + } + + if (!errorMsgs.isEmpty()) { + throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); + } + } + + /** + * ::Internal API:: + * Marks the task for interruption, i.e. cancellation. + */ + public void markInterrupted() { + interrupted = true; + } + + @Deprecated + /** Deprecated: use getStageId() */ + public int stageId() { + return stageId; + } + + @Deprecated + /** Deprecated: use getPartitionId() */ + public int partitionId() { + return partitionId; + } + + @Deprecated + /** Deprecated: use getAttemptId() */ + public long attemptId() { + return attemptId; + } + + @Deprecated + /** Deprecated: use getRunningLocally() */ + public boolean runningLocally() { + return runningLocally; + } + + public boolean getRunningLocally() { + return runningLocally; + } + + public int getStageId() { + return stageId; + } + + public int getPartitionId() { + return partitionId; + } + + public long getAttemptId() { + return attemptId; + } + + /** ::Internal API:: */ + public TaskMetrics taskMetrics() { + return taskMetrics; + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala deleted file mode 100644 index 51b3e4d5e0936..0000000000000 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} - - -/** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ -@DeveloperApi -class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable with Logging { - - @deprecated("use partitionId", "0.8.1") - def splitId = partitionId - - // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] - - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false - - // Whether the task has completed. - @volatile private var completed: Boolean = false - - /** Checks whether the task has completed. */ - def isCompleted: Boolean = completed - - /** Checks whether the task has been killed. */ - def isInterrupted: Boolean = interrupted - - // TODO: Also track whether the task has completed successfully or with exception. - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener - this - } - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } - this - } - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.1.0") - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - - /** Marks the task as completed and triggers the listeners. */ - private[spark] def markTaskCompleted(): Unit = { - completed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => - try { - listener.onTaskCompletion(this) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) - } - } - - /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(): Unit = { - interrupted = true - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0e90caa5c9ca7..ba712c9d7776f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi + @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b2774dfc47553..32cf29ed140e6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,12 +634,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true) + new TaskContext(job.finalStage.id, job.partitions(0), 0, true) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() + TaskContext.remove() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6aa0cca06878d..bf73f6f7bd0e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context = new TaskContext(stageId, partitionId, attemptId, false) + TaskContext.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + TaskContext.remove() + } } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8c23d524e00b..4a078435447e5 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 90dcadcffd091..d735010d7c9d5 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true) + val context = new TaskContext(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } From a3feaf04dc35069b80233fe7cccd62fc3072fc1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 26 Sep 2014 21:44:10 -0700 Subject: [PATCH 098/315] Close #2194. From e976ca236f3c5578d8d7639b788774b1053b65f7 Mon Sep 17 00:00:00 2001 From: Sarah Gerweck Date: Fri, 26 Sep 2014 22:21:50 -0700 Subject: [PATCH 099/315] Slaves file is now a template. Change 0dc868e removed the `conf/slaves` file and made it a template like most of the other configuration files. This means you can no longer run `make-distribution.sh` unless you manually create a slaves file to be statically bundled in your distribution, which seems at odds with making it a template file. Author: Sarah Gerweck Closes #2549 from sarahgerweck/noMoreSlaves and squashes the following commits: d11d99a [Sarah Gerweck] Slaves file is now a template. --- make-distribution.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index 884659954a491..0bc839e1dbe4d 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -201,7 +201,6 @@ fi # Copy other things mkdir "$DISTDIR"/conf cp "$FWDIR"/conf/*.template "$DISTDIR"/conf -cp "$FWDIR"/conf/slaves "$DISTDIR"/conf cp "$FWDIR/README.md" "$DISTDIR" cp -r "$FWDIR/bin" "$DISTDIR" cp -r "$FWDIR/python" "$DISTDIR" From 0cdcdd2c9df98fb64d9d16ebace992fbba9c16b4 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 26 Sep 2014 22:23:49 -0700 Subject: [PATCH 100/315] [Build]remove spark-staging-1030 Since 1.1.0 has published, remove spark-staging-1030. Author: wangfei Closes #2532 from scwf/patch-2 and squashes the following commits: bc9e00b [wangfei] remove spark-staging-1030 --- pom.xml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pom.xml b/pom.xml index f3de097b9cb32..70cb9729ff6d3 100644 --- a/pom.xml +++ b/pom.xml @@ -222,18 +222,6 @@ false - - - spark-staging-1030 - Spark 1.1.0 Staging (1030) - https://repository.apache.org/content/repositories/orgapachespark-1030/ - - true - - - false - - From f0eea76d941c487763febbd9162600f89cedbd5c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 26 Sep 2014 22:24:34 -0700 Subject: [PATCH 101/315] [SQL][DOCS] Clarify that the server is for JDBC and ODBC Author: Michael Armbrust Closes #2527 from marmbrus/patch-1 and squashes the following commits: a0f9f1c [Michael Armbrust] [SQL][DOCS] Clarify that the server is for JDBC and ODBC --- docs/sql-programming-guide.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c1f80544bf0af..65249808fae3e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -872,12 +872,12 @@ that these options will be deprecated in future release as more optimizations ar Spark SQL also supports interfaces for running SQL queries directly without the need to write any code. -## Running the Thrift JDBC server +## Running the Thrift JDBC/ODBC server -The Thrift JDBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) +The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. -To start the JDBC server, run the following in the Spark directory: +To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh @@ -906,11 +906,11 @@ or system properties: ``` {% endhighlight %} -Now you can use beeline to test the Thrift JDBC server: +Now you can use beeline to test the Thrift JDBC/ODBC server: ./bin/beeline -Connect to the JDBC server in beeline with: +Connect to the JDBC/ODBC server in beeline with: beeline> !connect jdbc:hive2://localhost:10000 From d8a9d1d442dd5612f82edaf2a780579c4d43dcfd Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 26 Sep 2014 22:30:12 -0700 Subject: [PATCH 102/315] [SPARK-3675][SQL] Allow starting a JDBC server on an existing context Author: Michael Armbrust Closes #2515 from marmbrus/jdbcExistingContext and squashes the following commits: 7866fad [Michael Armbrust] Allows starting a JDBC server on an existing context. --- .../sql/hive/thriftserver/HiveThriftServer2.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index cadf7aaf42157..3d468d804622c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -26,6 +26,7 @@ import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -33,9 +34,21 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. */ -private[hive] object HiveThriftServer2 extends Logging { +object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) + /** + * :: DeveloperApi :: + * Starts a new thrift server with the given context. + */ + @DeveloperApi + def startWithContext(sqlContext: HiveContext): Unit = { + val server = new HiveThriftServer2(sqlContext) + server.init(sqlContext.hiveconf) + server.start() + } + + def main(args: Array[String]) { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") From 9e8ced7847d84d63f0da08b15623d558a2407583 Mon Sep 17 00:00:00 2001 From: Jeff Steinmetz Date: Fri, 26 Sep 2014 23:00:40 -0700 Subject: [PATCH 103/315] stop, start and destroy require the EC2_REGION i.e ./spark-ec2 --region=us-west-1 stop yourclustername Author: Jeff Steinmetz Closes #2473 from jeffsteinmetz/master and squashes the following commits: 7491f2c [Jeff Steinmetz] fix case in EC2 cluster setup documentation bd3d777 [Jeff Steinmetz] standardized ec2 documenation to use sample args 2bf4a57 [Jeff Steinmetz] standardized ec2 documenation to use sample args 68d8372 [Jeff Steinmetz] standardized ec2 documenation to use sample args d2ab6e2 [Jeff Steinmetz] standardized ec2 documenation to use sample args 520e6dc [Jeff Steinmetz] standardized ec2 documenation to use sample args 37fc876 [Jeff Steinmetz] stop, start and destroy require the EC2_REGION --- docs/ec2-scripts.md | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index b2ca6a9b48f32..530798f2b8022 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -48,6 +48,15 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console. key pair, `` is the number of slave nodes to launch (try 1 at first), and `` is the name to give to your cluster. + + For example: + + ```bash + export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU +export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 +./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster + ``` + - After everything launches, check that the cluster scheduler is up and sees all the slaves by going to its web UI, which will be printed at the end of the script (typically `http://:8080`). @@ -55,27 +64,27 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console. You can also run `./spark-ec2 --help` to see more usage options. The following options are worth pointing out: -- `--instance-type=` can be used to specify an EC2 +- `--instance-type=` can be used to specify an EC2 instance type to use. For now, the script only supports 64-bit instance types, and the default type is `m1.large` (which has 2 cores and 7.5 GB RAM). Refer to the Amazon pages about [EC2 instance types](http://aws.amazon.com/ec2/instance-types) and [EC2 pricing](http://aws.amazon.com/ec2/#pricing) for information about other instance types. -- `--region=` specifies an EC2 region in which to launch +- `--region=` specifies an EC2 region in which to launch instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone +- `--zone=` can be used to specify an EC2 availability zone to launch instances in. Sometimes, you will get an error because there is not enough capacity in one zone, and you should try to launch in another. -- `--ebs-vol-size=GB` will attach an EBS volume with a given amount +- `--ebs-vol-size=` will attach an EBS volume with a given amount of space to each node so that you can have a persistent HDFS cluster on your nodes across cluster restarts (see below). -- `--spot-price=PRICE` will launch the worker nodes as +- `--spot-price=` will launch the worker nodes as [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), bidding for the given maximum price (in dollars). -- `--spark-version=VERSION` will pre-load the cluster with the - specified version of Spark. VERSION can be a version number +- `--spark-version=` will pre-load the cluster with the + specified version of Spark. The `` can be a version number (e.g. "0.7.3") or a specific git hash. By default, a recent version will be used. - If one of your launches fails due to e.g. not having the right @@ -137,11 +146,11 @@ cost you any EC2 cycles, but ***will*** continue to cost money for EBS storage. - To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 stop `. +`./spark-ec2 --region= stop `. - To restart it later, run -`./spark-ec2 -i start `. +`./spark-ec2 -i --region= start `. - To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 destroy ` as described in the previous +`./spark-ec2 --region= destroy ` as described in the previous section. # Limitations From 2d972fd84ac54a89e416442508a6d4eaeff452c1 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Fri, 26 Sep 2014 23:15:10 -0700 Subject: [PATCH 104/315] [SPARK-1021] Defer the data-driven computation of partition bounds in so... ...rtByKey() until evaluation. Author: Erik Erlandson Closes #1689 from erikerlandson/spark-1021-pr and squashes the following commits: 50b6da6 [Erik Erlandson] use standard getIteratorSize in countAsync 4e334a9 [Erik Erlandson] exception mystery fixed by fixing bug in ComplexFutureAction b88b5d4 [Erik Erlandson] tweak async actions to use ComplexFutureAction[T] so they handle RangePartitioner sampling job properly b2b20e8 [Erik Erlandson] Fix bug in exception passing with ComplexFutureAction[T] ca8913e [Erik Erlandson] RangePartition sampling job -> FutureAction 7143f97 [Erik Erlandson] [SPARK-1021] modify range bounds variable to be thread safe ac67195 [Erik Erlandson] [SPARK-1021] Defer the data-driven computation of partition bounds in sortByKey() until evaluation. --- .../scala/org/apache/spark/FutureAction.scala | 7 +- .../scala/org/apache/spark/Partitioner.scala | 29 +++++++-- .../apache/spark/rdd/AsyncRDDActions.scala | 64 +++++++++++-------- 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 75ea535f2f57b..c277c3a47d421 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -208,7 +208,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { + resultFunc: => R): R = { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { @@ -223,7 +223,10 @@ class ComplexFutureAction[T] extends FutureAction[T] { // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. try { - Await.ready(job, Duration.Inf) + Await.ready(job, Duration.Inf).value.get match { + case scala.util.Failure(e) => throw e + case scala.util.Success(v) => v + } } catch { case e: InterruptedException => job.cancel() diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 37053bb6f37ad..d40b152d221c5 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -29,6 +29,10 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} +import org.apache.spark.SparkContext.rddToAsyncRDDActions +import scala.concurrent.Await +import scala.concurrent.duration.Duration + /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. @@ -113,8 +117,12 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - private var rangeBounds: Array[K] = { - if (partitions <= 1) { + @volatile private var valRB: Array[K] = null + + private def rangeBounds: Array[K] = this.synchronized { + if (valRB != null) return valRB + + valRB = if (partitions <= 1) { Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. @@ -152,6 +160,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( RangePartitioner.determineBounds(candidates, partitions) } } + + valRB } def numPartitions = rangeBounds.length + 1 @@ -222,7 +232,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = this.synchronized { + if (valRB != null) return val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() @@ -234,7 +245,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val ser = sfactory.newInstance() Utils.deserializeViaNestedStream(in, ser) { ds => implicit val classTag = ds.readObject[ClassTag[Array[K]]]() - rangeBounds = ds.readObject[Array[K]]() + valRB = ds.readObject[Array[K]]() } } } @@ -254,12 +265,18 @@ private[spark] object RangePartitioner { sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object - val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => + // use collectAsync here to run this job as a future, which is cancellable + val sketchFuture = rdd.mapPartitionsWithIndex { (idx, iter) => val seed = byteswap32(idx ^ (shift << 16)) val (sample, n) = SamplingUtils.reservoirSampleAndCount( iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) - }.collect() + }.collectAsync() + // We do need the future's value to continue any further + val sketched = Await.ready(sketchFuture, Duration.Inf).value.get match { + case scala.util.Success(v) => v.toArray + case scala.util.Failure(e) => throw e + } val numItems = sketched.map(_._2.toLong).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..7a68b3afa8158 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag +import org.apache.spark.util.Utils import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} import org.apache.spark.annotation.Experimental @@ -38,29 +39,30 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for counting the number of elements in the RDD. */ def countAsync(): FutureAction[Long] = { - val totalCount = new AtomicLong - self.context.submitJob( - self, - (iter: Iterator[T]) => { - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - }, - Range(0, self.partitions.size), - (index: Int, data: Long) => totalCount.addAndGet(data), - totalCount.get()) + val f = new ComplexFutureAction[Long] + f.run { + val totalCount = new AtomicLong + f.runJob(self, + (iter: Iterator[T]) => Utils.getIteratorSize(iter), + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) + } } /** * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), - (index, data) => results(index) = data, results.flatten.toSeq) + val f = new ComplexFutureAction[Seq[T]] + f.run { + val results = new Array[Array[T]](self.partitions.size) + f.runJob(self, + (iter: Iterator[T]) => iter.toArray, + Range(0, self.partitions.size), + (index: Int, data: Array[T]) => results(index) = data, + results.flatten.toSeq) + } } /** @@ -104,24 +106,34 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } results.toSeq } - - f } /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(f: T => Unit): FutureAction[Unit] = { - val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), - (index, data) => Unit, Unit) + def foreachAsync(expr: T => Unit): FutureAction[Unit] = { + val f = new ComplexFutureAction[Unit] + val exprClean = self.context.clean(expr) + f.run { + f.runJob(self, + (iter: Iterator[T]) => iter.foreach(exprClean), + Range(0, self.partitions.size), + (index: Int, data: Unit) => Unit, + Unit) + } } /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), - (index, data) => Unit, Unit) + def foreachPartitionAsync(expr: Iterator[T] => Unit): FutureAction[Unit] = { + val f = new ComplexFutureAction[Unit] + f.run { + f.runJob(self, + expr, + Range(0, self.partitions.size), + (index: Int, data: Unit) => Unit, + Unit) + } } } From 436a7730b6e7067f74b3739a3a412490003f7c4c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 00:57:26 -0700 Subject: [PATCH 105/315] Minor cleanup to tighten visibility and remove compilation warning. Author: Reynold Xin Closes #2555 from rxin/cleanup and squashes the following commits: 6add199 [Reynold Xin] Minor cleanup to tighten visibility and remove compilation warning. --- .../input/WholeTextFileRecordReader.scala | 24 +++++----- .../apache/spark/metrics/MetricsSystem.scala | 28 ++++++----- .../spark/metrics/MetricsSystemSuite.scala | 33 +++++++------ .../streaming/StreamingContextSuite.scala | 47 ++++++++++--------- 4 files changed, 70 insertions(+), 62 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index c3dabd2e79995..3564ab2e2a162 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -36,33 +36,31 @@ private[spark] class WholeTextFileRecordReader( index: Integer) extends RecordReader[String, String] { - private val path = split.getPath(index) - private val fs = path.getFileSystem(context.getConfiguration) + private[this] val path = split.getPath(index) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. - private var processed = false + private[this] var processed = false - private val key = path.toString - private var value: String = null + private[this] val key = path.toString + private[this] var value: String = null - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} - override def close() = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: String = value - override def nextKeyValue = { + override def nextKeyValue(): Boolean = { if (!processed) { val fileIn = fs.open(path) val innerBuffer = ByteStreams.toByteArray(fileIn) - value = new Text(innerBuffer).toString Closeables.close(fileIn, false) - processed = true true } else { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 6ef817d0e587e..fd316a89a1a10 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -63,15 +63,18 @@ import org.apache.spark.metrics.source.Source * * [options] is the specific property of this source or sink. */ -private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf, securityMgr: SecurityManager) extends Logging { +private[spark] class MetricsSystem private ( + val instance: String, + conf: SparkConf, + securityMgr: SecurityManager) + extends Logging { - val confFile = conf.get("spark.metrics.conf", null) - val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val confFile = conf.get("spark.metrics.conf", null) + private[this] val metricsConfig = new MetricsConfig(Option(confFile)) - val sinks = new mutable.ArrayBuffer[Sink] - val sources = new mutable.ArrayBuffer[Source] - val registry = new MetricRegistry() + private val sinks = new mutable.ArrayBuffer[Sink] + private val sources = new mutable.ArrayBuffer[Source] + private val registry = new MetricRegistry() // Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui private var metricsServlet: Option[MetricsServlet] = None @@ -91,7 +94,7 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } - def report(): Unit = { + def report() { sinks.foreach(_.report()) } @@ -155,8 +158,8 @@ private[spark] object MetricsSystem { val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r - val MINIMAL_POLL_UNIT = TimeUnit.SECONDS - val MINIMAL_POLL_PERIOD = 1 + private[this] val MINIMAL_POLL_UNIT = TimeUnit.SECONDS + private[this] val MINIMAL_POLL_PERIOD = 1 def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) { val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit) @@ -166,7 +169,8 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf, - securityMgr: SecurityManager): MetricsSystem = + def createMetricsSystem( + instance: String, conf: SparkConf, securityMgr: SecurityManager): MetricsSystem = { new MetricsSystem(instance, conf, securityMgr) + } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 96a5a1231813e..e42b181194727 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,42 +17,47 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.metrics.source.Source +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource -class MetricsSystemSuite extends FunSuite with BeforeAndAfter { +import scala.collection.mutable.ArrayBuffer + + +class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null var securityMgr: SecurityManager = null before { - filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() + filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile conf = new SparkConf(false).set("spark.metrics.conf", filePath) securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks + val sources = PrivateMethod[ArrayBuffer[Source]]('sources) + val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(sources.length === 0) - assert(sinks.length === 0) - assert(!metricsSystem.getServletHandlers.isEmpty) + assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sinks()).length === 0) + assert(metricsSystem.getServletHandlers.nonEmpty) } test("MetricsSystem with sources add") { val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks + val sources = PrivateMethod[ArrayBuffer[Source]]('sources) + val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(sources.length === 0) - assert(sinks.length === 1) - assert(!metricsSystem.getServletHandlers.isEmpty) + assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sinks()).length === 1) + assert(metricsSystem.getServletHandlers.nonEmpty) val source = new MasterSource(null) metricsSystem.registerSource(source) - assert(sources.length === 1) + assert(metricsSystem.invokePrivate(sources()).length === 1) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index ebf83748ffa28..655cec1573f58 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,18 +19,18 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger -import scala.language.postfixOps +import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts -import org.scalatest.concurrent.Eventually._ -import org.scalatest.exceptions.TestFailedDueToTimeoutException -import org.scalatest.time.SpanSugar._ + class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { @@ -68,7 +68,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from no conf + spark home + env") { ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil, Map(envPair)) - assert(ssc.conf.getExecutorEnv.exists(_ == envPair)) + assert(ssc.conf.getExecutorEnv.contains(envPair)) } test("from conf with settings") { @@ -94,7 +94,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.cleaner.ttl", "10") val ssc1 = new StreamingContext(myConf, batchDuration) - addInputStream(ssc1).register + addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") @@ -107,7 +107,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() assert(ssc.state === ssc.StreamingContextState.Initialized) ssc.start() @@ -118,7 +118,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() intercept[SparkException] { ssc.start() @@ -127,7 +127,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop multiple times") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() ssc.stop() ssc.stop() @@ -135,7 +135,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop before start and start after stop") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception ssc.start() ssc.stop() @@ -147,12 +147,12 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop only streaming context") { ssc = new StreamingContext(master, appName, batchDuration) sc = ssc.sparkContext - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() - ssc.stop(false) + ssc.stop(stopSparkContext = false) assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() ssc.stop() } @@ -167,11 +167,11 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) - input.count.foreachRDD(rdd => { + input.count().foreachRDD { rdd => val count = rdd.first() runningCount += count.toInt logInfo("Count = " + count + ", Running count = " + runningCount) - }) + } ssc.start() ssc.awaitTermination(500) ssc.stop(stopSparkContext = false, stopGracefully = true) @@ -191,7 +191,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.map(x => x).register + inputStream.map(x => x).register() // test whether start() blocks indefinitely or not failAfter(2000 millis) { @@ -215,7 +215,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown new Thread() { - override def run { + override def run() { Thread.sleep(500) ssc.stop() } @@ -239,8 +239,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination with error in task") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.map(x => { throw new TestException("error in map task"); x}) - .foreachRDD(_.count) + inputStream + .map { x => throw new TestException("error in map task"); x } + .foreachRDD(_.count()) val exception = intercept[Exception] { ssc.start() @@ -252,7 +253,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination with error in job generation") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register + inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register() val exception = intercept[TestException] { ssc.start() ssc.awaitTermination(5000) @@ -265,7 +266,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } def addInputStream(s: StreamingContext): DStream[Int] = { - val input = (1 to 100).map(i => (1 to i)) + val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) inputStream } From 66107f46f374f83729cd79ab260eb59fa123c041 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Sat, 27 Sep 2014 09:41:04 -0700 Subject: [PATCH 106/315] Docs : use "--total-executor-cores" rather than "--cores" after spark-shell Author: CrazyJvm Closes #2540 from CrazyJvm/standalone-core and squashes the following commits: 66d9fc6 [CrazyJvm] use "--total-executor-cores" rather than "--cores" after spark-shell --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 58103fab20819..a3028aa86dc45 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -247,7 +247,7 @@ To run an interactive Spark shell against the cluster, run the following command ./bin/spark-shell --master spark://IP:PORT -You can also pass an option `--cores ` to control the number of cores that spark-shell uses on the cluster. +You can also pass an option `--total-executor-cores ` to control the number of cores that spark-shell uses on the cluster. # Launching Compiled Spark Applications From 0800881051df8029afb22a4ec17970e316a85855 Mon Sep 17 00:00:00 2001 From: w00228970 Date: Sat, 27 Sep 2014 12:06:06 -0700 Subject: [PATCH 107/315] [SPARK-3676][SQL] Fix hive test suite failure due to diffs in JDK 1.6/1.7 This is a bug in JDK6: http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4428022 this is because jdk get different result to operate ```double```, ```System.out.println(1/500d)``` in different jdk get different result jdk 1.6.0(_31) ---- 0.0020 jdk 1.7.0(_05) ---- 0.002 this leads to HiveQuerySuite failed when generate golden answer in jdk 1.7 and run tests in jdk 1.6, result did not match Author: w00228970 Closes #2517 from scwf/HiveQuerySuite and squashes the following commits: 0cb5e8d [w00228970] delete golden answer of division-0 and timestamp cast #1 1df3964 [w00228970] Jdk version leads to different query output for Double, this make HiveQuerySuite failed --- .../division-0-63b19f8a22471c8ba0415c1d3bc276f7 | 1 - ...amp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 | 1 - .../spark/sql/hive/execution/HiveQuerySuite.scala | 15 +++++++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 deleted file mode 100644 index 7b7a9175114ce..0000000000000 --- a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 +++ /dev/null @@ -1 +0,0 @@ -2.0 0.5 0.3333333333333333 0.002 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 deleted file mode 100644 index 8ebf695ba7d20..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 +++ /dev/null @@ -1 +0,0 @@ -0.001 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2f876cafaf218..2da8a6fac3d99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -135,8 +135,12 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("div", "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1") - createQueryTest("division", - "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") + // Jdk version leads to different query output for double, so not use createQueryTest here + test("division") { + val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head + Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x => + assert(x._1 == x._2.asInstanceOf[Double])) + } createQueryTest("modulus", "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") @@ -306,8 +310,11 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("case statements WITHOUT key #4", "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") - createQueryTest("timestamp cast #1", - "SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + // Jdk version leads to different query output for double, so not use createQueryTest here + test("timestamp cast #1") { + val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(0.001 == res.getDouble(0)) + } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From f0c7e19550d46f81a0a3ff272bbf66ce4bafead6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 27 Sep 2014 12:10:16 -0700 Subject: [PATCH 108/315] [SPARK-3680][SQL] Fix bug caused by eager typing of HiveGenericUDFs Typing of UDFs should be lazy as it is often not valid to call `dataType` on an expression until after all of its children are `resolved`. Author: Michael Armbrust Closes #2525 from marmbrus/concatBug and squashes the following commits: 5b8efe7 [Michael Armbrust] fix bug with eager typing of udfs --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 2 +- .../spark/sql/parquet/ParquetMetastoreSuite.scala | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 68944ed4ef21d..732e4976f6843 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -151,7 +151,7 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq override def get(): AnyRef = wrap(func()) } - val dataType: DataType = inspectorToDataType(returnInspector) + lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: Row): Any = { returnInspector // Make sure initialized. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala index e380280f301c1..86adbbf3ad2d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet import java.io.File +import org.apache.spark.sql.catalyst.expressions.Row import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest @@ -142,15 +143,21 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { test("sum") { checkAnswer( sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"), - 1 + 2 + 3 - ) + 1 + 2 + 3) + } + + test("hive udfs") { + checkAnswer( + sql("SELECT concat(stringField, stringField) FROM partitioned_parquet"), + sql("SELECT stringField FROM partitioned_parquet").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) } test("non-part select(*)") { checkAnswer( sql("SELECT COUNT(*) FROM normal_parquet"), - 10 - ) + 10) } test("conversion is working") { From 0d8cdf0ede908f6c488a075170f1563815009e29 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 12:21:37 -0700 Subject: [PATCH 109/315] [SPARK-3681] [SQL] [PySpark] fix serialization of List and Map in SchemaRDD Currently, the schema of object in ArrayType or MapType is attached lazily, it will have better performance but introduce issues while serialization or accessing nested objects. This patch will apply schema to the objects of ArrayType or MapType immediately when accessing them, will be a little bit slower, but much robust. Author: Davies Liu Closes #2526 from davies/nested and squashes the following commits: 2399ae5 [Davies Liu] fix serialization of List and Map in SchemaRDD --- python/pyspark/sql.py | 40 +++++++++++++--------------------------- python/pyspark/tests.py | 21 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 653195ea438cf..f71d24c470dc9 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -838,43 +838,29 @@ def _create_cls(dataType): >>> obj = _create_cls(schema)(row) >>> pickle.loads(pickle.dumps(obj)) Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} """ if isinstance(dataType, ArrayType): cls = _create_cls(dataType.elementType) - class List(list): - - def __getitem__(self, i): - # create object with datetype - return _create_object(cls, list.__getitem__(self, i)) - - def __repr__(self): - # call collect __repr__ for nested objects - return "[%s]" % (", ".join(repr(self[i]) - for i in range(len(self)))) - - def __reduce__(self): - return list.__reduce__(self) + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] return List elif isinstance(dataType, MapType): - vcls = _create_cls(dataType.valueType) - - class Dict(dict): - - def __getitem__(self, k): - # create object with datetype - return _create_object(vcls, dict.__getitem__(self, k)) - - def __repr__(self): - # call collect __repr__ for nested objects - return "{%s}" % (", ".join("%r: %r" % (k, self[k]) - for k in self)) + cls = _create_cls(dataType.valueType) - def __reduce__(self): - return dict.__reduce__(self) + def Dict(d): + if d is None: + return + return dict((k, _create_object(cls, v)) for k, v in d.items()) return Dict diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index d1bb2033b7a16..29df754c6fd29 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -698,6 +698,27 @@ def test_apply_schema_to_row(self): srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) self.assertEqual(10, srdd3.count()) + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + row = srdd.first() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = srdd.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = srdd.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = srdd.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + class TestIO(PySparkTestCase): From 5b922bb458e863f5be0ae68167de882743f70b86 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 14:46:00 -0700 Subject: [PATCH 110/315] [SPARK-3543] Clean up Java TaskContext implementation. This addresses some minor issues in https://github.com/apache/spark/pull/2425 Author: Reynold Xin Closes #2557 from rxin/TaskContext and squashes the following commits: a51e5f6 [Reynold Xin] [SPARK-3543] Clean up Java TaskContext implementation. --- .../java/org/apache/spark/TaskContext.java | 33 ++++++++----------- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/scheduler/ResultTask.scala | 6 +--- .../spark/scheduler/ShuffleMapTask.scala | 2 -- .../org/apache/spark/scheduler/Task.scala | 8 +++-- 5 files changed, 22 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 09b8ce02bd3d8..4e6d708af0ea7 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -56,7 +56,7 @@ public class TaskContext implements Serializable { * @param taskMetrics performance metrics of the task */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, TaskMetrics taskMetrics) { this.attemptId = attemptId; this.partitionId = partitionId; @@ -65,7 +65,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean this.taskMetrics = taskMetrics; } - /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. @@ -76,8 +75,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean * @param runningLocally whether the task is running locally in the driver JVM */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, - Boolean runningLocally) { + public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { this.attemptId = attemptId; this.partitionId = partitionId; this.runningLocally = runningLocally; @@ -85,7 +83,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, this.taskMetrics = TaskMetrics.empty(); } - /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. @@ -95,7 +92,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, * @param attemptId the number of attempts to execute this task */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + public TaskContext(int stageId, int partitionId, long attemptId) { this.attemptId = attemptId; this.partitionId = partitionId; this.runningLocally = false; @@ -107,9 +104,9 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { new ThreadLocal(); /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ + * :: Internal API :: + * This is spark internal API, not intended to be called from user programs. + */ public static void setTaskContext(TaskContext tc) { taskContext.set(tc); } @@ -118,10 +115,8 @@ public static TaskContext get() { return taskContext.get(); } - /** - * :: Internal API :: - */ - public static void remove() { + /** :: Internal API :: */ + public static void unset() { taskContext.remove(); } @@ -130,22 +125,22 @@ public static void remove() { new ArrayList(); // Whether the corresponding task has been killed. - private volatile Boolean interrupted = false; + private volatile boolean interrupted = false; // Whether the task has completed. - private volatile Boolean completed = false; + private volatile boolean completed = false; /** * Checks whether the task has completed. */ - public Boolean isCompleted() { + public boolean isCompleted() { return completed; } /** * Checks whether the task has been killed. */ - public Boolean isInterrupted() { + public boolean isInterrupted() { return interrupted; } @@ -246,12 +241,12 @@ public long attemptId() { } @Deprecated - /** Deprecated: use getRunningLocally() */ + /** Deprecated: use isRunningLocally() */ public boolean runningLocally() { return runningLocally; } - public boolean getRunningLocally() { + public boolean isRunningLocally() { return runningLocally; } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 32cf29ed140e6..70c235dffff70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -641,7 +641,7 @@ class DAGScheduler( job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.remove() + TaskContext.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 2ccbd8edeb028..4a9ff918afe25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -58,11 +58,7 @@ private[spark] class ResultTask[T, U]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) metrics = Some(context.taskMetrics) - try { - func(context, rdd.iterator(partition, context)) - } finally { - context.markTaskCompleted() - } + func(context, rdd.iterator(partition, context)) } // This is only callable on the driver side. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index a98ee118254a3..79709089c0da4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -78,8 +78,6 @@ private[spark] class ShuffleMapTask( log.debug("Could not stop writer", e) } throw e - } finally { - context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index bf73f6f7bd0e1..c6e47c84a0cb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -52,7 +52,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (_killed) { kill(interruptThread = false) } - runTask(context) + try { + runTask(context) + } finally { + context.markTaskCompleted() + TaskContext.unset() + } } def runTask(context: TaskContext): T @@ -93,7 +98,6 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - TaskContext.remove() } } From 248232936e1bead7f102e59eb8faf3126c582d9d Mon Sep 17 00:00:00 2001 From: Uri Laserson Date: Sat, 27 Sep 2014 21:48:05 -0700 Subject: [PATCH 111/315] [SPARK-3389] Add Converter for ease of Parquet reading in PySpark https://issues.apache.org/jira/browse/SPARK-3389 Author: Uri Laserson Closes #2256 from laserson/SPARK-3389 and squashes the following commits: 0ed363e [Uri Laserson] PEP8'd the python file 0b4b380 [Uri Laserson] Moved converter to examples and added python example eecf4dc [Uri Laserson] [SPARK-3389] Add Converter for ease of Parquet reading in PySpark --- .../src/main/python/parquet_inputformat.py | 59 ++++++++++++++ examples/src/main/resources/full_user.avsc | 1 + examples/src/main/resources/users.parquet | Bin 0 -> 615 bytes .../pythonconverters/AvroConverters.scala | 76 +++++++++++------- 4 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 examples/src/main/python/parquet_inputformat.py create mode 100644 examples/src/main/resources/full_user.avsc create mode 100644 examples/src/main/resources/users.parquet diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py new file mode 100644 index 0000000000000..c9b08f878a1e6 --- /dev/null +++ b/examples/src/main/python/parquet_inputformat.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.parquet in local Spark distro: + +$ cd $SPARK_HOME +$ export AVRO_PARQUET_JARS=/path/to/parquet-avro-1.5.0.jar +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \\ + --jars $AVRO_PARQUET_JARS \\ + ./examples/src/main/python/parquet_inputformat.py \\ + examples/src/main/resources/users.parquet +<...lots of log output...> +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} +<...more log output...> +""" +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, """ + Usage: parquet_inputformat.py + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar \\ + /path/to/examples/parquet_inputformat.py + Assumes you have Parquet data stored in . + """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="ParquetInputFormat") + + parquet_rdd = sc.newAPIHadoopFile( + path, + 'parquet.avro.AvroParquetInputFormat', + 'java.lang.Void', + 'org.apache.avro.generic.IndexedRecord', + valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') + output = parquet_rdd.map(lambda x: x[1]).collect() + for k in output: + print k diff --git a/examples/src/main/resources/full_user.avsc b/examples/src/main/resources/full_user.avsc new file mode 100644 index 0000000000000..04e7ba2dca4f6 --- /dev/null +++ b/examples/src/main/resources/full_user.avsc @@ -0,0 +1 @@ +{"type": "record", "namespace": "example.avro", "name": "User", "fields": [{"type": "string", "name": "name"}, {"type": ["string", "null"], "name": "favorite_color"}, {"type": {"items": "int", "type": "array"}, "name": "favorite_numbers"}]} \ No newline at end of file diff --git a/examples/src/main/resources/users.parquet b/examples/src/main/resources/users.parquet new file mode 100644 index 0000000000000000000000000000000000000000..aa527338c43a8400fd56e549cb28aa1e6a9ccccf GIT binary patch literal 615 zcmZuv%WA?v6dhv>skOF(GbAMx8A#|N4V6|9aiOIPms03PD`l!<8_27ZC>8M^`h9*v zzoIu$LutDhxO2}r_i<*1{f8z-m}1MuG6X7C5vuhRgizmG#W5>FbjJgL&hf>LqokaZ zYYC8|l;VQV0B_^Ajmr=y801Dh!>c8wcbamJ;GDv#!@-jNG^p_p=0_fP*iwYfW6TA} zaK%KL95A1oK&zONR-LnDDBOfUPeU&hkZvLEEKddt|AmVfOQ~2gWv#@7U@Jsq-O#(1 zYT$})sz~1z#S)RpJsDWAfHh39mWmYpcax0PB|U2hw9kS8^O``r{L^;V4CrMtA|s$8 zM79MYBi+!Bv%TW!8~2&EEv#v>ia701!Ka~^QJbb)!ad!5e~TkFO;bOe0ch@WZx++e zczw`hQu|ObPJ|o0(v6+txjmU@P-546O!ri1zVJLc`A@QUG#BNAXU0Mr-ol4zs2e17 jvzcs=rbSG=FL-k0i^dXO!wrK*)46qS&=)u|gfI3D{z{mb literal 0 HcmV?d00001 diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 1b25983a38453..a11890d6f2b1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -30,21 +30,28 @@ import org.apache.spark.api.python.Converter import org.apache.spark.SparkException -/** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts - * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries - * to work with all 3 Avro data mappings (Generic, Specific and Reflect). - */ -class AvroWrapperToJavaConverter extends Converter[Any, Any] { - override def convert(obj: Any): Any = { +object AvroConversionUtil extends Serializable { + def fromAvro(obj: Any, schema: Schema): Any = { if (obj == null) { return null } - obj.asInstanceOf[AvroWrapper[_]].datum() match { - case null => null - case record: IndexedRecord => unpackRecord(record) - case other => throw new SparkException( - s"Unsupported top-level Avro data type ${other.getClass.getName}") + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") } } @@ -103,28 +110,37 @@ class AvroWrapperToJavaConverter extends Converter[Any, Any] { "Unions may only consist of a concrete type and null") } } +} - def fromAvro(obj: Any, schema: Schema): Any = { +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro IndexedRecord (e.g., derived from AvroParquetInputFormat) to a Java Map. + */ +class IndexedRecordToJavaConverter extends Converter[IndexedRecord, JMap[String, Any]]{ + override def convert(record: IndexedRecord): JMap[String, Any] = { + if (record == null) { + return null + } + val map = new java.util.HashMap[String, Any] + AvroConversionUtil.unpackRecord(record) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { if (obj == null) { return null } - schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj - case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => AvroConversionUtil.unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") } } } From 9966d1a8aaed3d8cfed93855959705ea3c677215 Mon Sep 17 00:00:00 2001 From: Dale Date: Sat, 27 Sep 2014 22:08:10 -0700 Subject: [PATCH 112/315] SPARK-CORE [SPARK-3651] Group common CoarseGrainedSchedulerBackend variables together from [SPARK-3651] In CoarseGrainedSchedulerBackend, we have: private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] private val freeCores = new HashMap[String, Int] private val totalCores = new HashMap[String, Int] We only ever put / remove stuff from these maps together. It would simplify the code if we consolidate these all into one map as we have done in JobProgressListener in https://issues.apache.org/jira/browse/SPARK-2299. Author: Dale Closes #2533 from tigerquoll/SPARK-3651 and squashes the following commits: d1be0a9 [Dale] [SPARK-3651] implemented suggested changes. Changed a reference from executorInfo to executorData to be consistent with other usages 6890663 [Dale] [SPARK-3651] implemented suggested changes 7d671cf [Dale] [SPARK-3651] Grouped variables under a ExecutorDataObject, and reference them via a map entry as they are all retrieved under the same key --- .../CoarseGrainedSchedulerBackend.scala | 68 ++++++++----------- .../scheduler/cluster/ExecutorData.scala | 38 +++++++++++ 2 files changed, 68 insertions(+), 38 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9a0cb1c6c6ccd..59e15edc75f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -62,15 +62,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A val createTime = System.currentTimeMillis() class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { - override protected def log = CoarseGrainedSchedulerBackend.this.log - - private val executorActor = new HashMap[String, ActorRef] - private val executorAddress = new HashMap[String, Address] - private val executorHost = new HashMap[String, String] - private val freeCores = new HashMap[String, Int] - private val totalCores = new HashMap[String, Int] private val addressToExecutorId = new HashMap[Address, String] + private val executorDataMap = new HashMap[String, ExecutorData] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -85,16 +79,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorActor.contains(executorId)) { + if (executorDataMap.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredExecutor - executorActor(executorId) = sender - executorHost(executorId) = Utils.parseHostPort(hostPort)._1 - totalCores(executorId) = cores - freeCores(executorId) = cores - executorAddress(executorId) = sender.path.address + executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address, + Utils.parseHostPort(hostPort)._1, cores, cores)) + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -104,13 +96,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - if (executorActor.contains(executorId)) { - freeCores(executorId) += scheduler.CPUS_PER_TASK - makeOffers(executorId) - } else { - // Ignoring the update since we don't know about the executor. - val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s" - logWarning(msg.format(taskId, state, sender, executorId)) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorInfo.freeCores += scheduler.CPUS_PER_TASK + makeOffers(executorId) + case None => + // Ignoring the update since we don't know about the executor. + logWarning(s"Ignored task status update ($taskId state $state) " + + "from unknown executor $sender with ID $executorId") } } @@ -118,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A makeOffers() case KillTask(taskId, executorId, interruptThread) => - executorActor(executorId) ! KillTask(taskId, executorId, interruptThread) + executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread) case StopDriver => sender ! true @@ -126,8 +119,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case StopExecutors => logInfo("Asking each executor to shut down") - for (executor <- executorActor.values) { - executor ! StopExecutor + for ((_, executorData) <- executorDataMap) { + executorData.executorActor ! StopExecutor } sender ! true @@ -138,6 +131,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) sender ! true + case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -149,13 +143,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorDataMap.map {case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores)}.toSeq)) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { + val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers @@ -179,25 +175,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } else { - freeCores(task.executorId) -= scheduler.CPUS_PER_TASK - executorActor(task.executorId) ! LaunchTask(new SerializableBuffer(serializedTask)) + val executorData = executorDataMap(task.executorId) + executorData.freeCores -= scheduler.CPUS_PER_TASK + executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) } } } // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String) { - if (executorActor.contains(executorId)) { - logInfo("Executor " + executorId + " disconnected, so removing it") - val numCores = totalCores(executorId) - executorActor -= executorId - executorHost -= executorId - addressToExecutorId -= executorAddress(executorId) - executorAddress -= executorId - totalCores -= executorId - freeCores -= executorId - totalCoreCount.addAndGet(-numCores) - scheduler.executorLost(executorId, SlaveLost(reason)) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorDataMap -= executorId + totalCoreCount.addAndGet(-executorInfo.totalCores) + scheduler.executorLost(executorId, SlaveLost(reason)) + case None => logError(s"Asked to remove non existant executor $executorId") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala new file mode 100644 index 0000000000000..74a92985b6629 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import akka.actor.{Address, ActorRef} + +/** + * Grouping of data that is accessed by a CourseGrainedScheduler. This class + * is stored in a Map keyed by an executorID + * + * @param executorActor The actorRef representing this executor + * @param executorAddress The network address of this executor + * @param executorHost The hostname that this executor is running on + * @param freeCores The current number of cores available for work on the executor + * @param totalCores The total number of cores available to the executor + */ +private[cluster] class ExecutorData( + val executorActor: ActorRef, + val executorAddress: Address, + val executorHost: String , + var freeCores: Int, + val totalCores: Int +) From 66e1c40c67f40dc4a5519812bc84877751933e7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 22:18:02 -0700 Subject: [PATCH 113/315] Minor fix for the previous commit. --- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6 +++--- .../org/apache/spark/scheduler/cluster/ExecutorData.scala | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 59e15edc75f5a..89089e7d6f8a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -142,9 +142,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A // Make fake resource offers on all executors def makeOffers() { - launchTasks(scheduler.resourceOffers( - executorDataMap.map {case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores)}.toSeq)) + launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + }.toSeq)) } // Make fake resource offers on just one executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 74a92985b6629..b71bd5783d6df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -20,10 +20,9 @@ package org.apache.spark.scheduler.cluster import akka.actor.{Address, ActorRef} /** - * Grouping of data that is accessed by a CourseGrainedScheduler. This class - * is stored in a Map keyed by an executorID + * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The actorRef representing this executor + * @param executorActor The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor From 6918012d0f4841c5422b5827879a952428ec3a62 Mon Sep 17 00:00:00 2001 From: William Benton Date: Sun, 28 Sep 2014 01:01:27 -0700 Subject: [PATCH 114/315] SPARK-3699: SQL and Hive console tasks now clean up appropriately The sbt tasks sql/console and hive/console will now `stop()` the `SparkContext` upon exit. Previously, they left an ugly stack trace when quitting. Author: William Benton Closes #2547 from willb/consoleCleanup and squashes the following commits: d5e431f [William Benton] SQL and Hive console tasks now clean up. --- project/SparkBuild.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 12ac82293df76..01a5b20e7c51d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,7 +221,8 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + cleanupCommands in console := "sparkContext.stop()" ) } @@ -249,7 +250,8 @@ object Hive { |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + cleanupCommands in console := "sparkContext.stop()" ) } From 1f13a40ccd5a869aec62788a1e345dc24fa648c8 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Sun, 28 Sep 2014 18:30:13 -0700 Subject: [PATCH 115/315] [SPARK-3715][Docs]minor typo https://issues.apache.org/jira/browse/SPARK-3715 Author: WangTaoTheTonic Closes #2567 from WangTaoTheTonic/minortypo and squashes the following commits: 9cc3f7a [WangTaoTheTonic] minor typo --- docs/sql-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 65249808fae3e..818fd5ab80af8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -732,7 +732,7 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.

    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do +adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can still create a HiveContext. When not configured by the hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current directory. @@ -753,7 +753,7 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
    When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to +adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. @@ -774,7 +774,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to +adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. From 8e874185ed9efae8a1dc6b61d56ff401d72bb087 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Sep 2014 18:33:11 -0700 Subject: [PATCH 116/315] Revert "[SPARK-1021] Defer the data-driven computation of partition bounds in so..." This reverts commit 2d972fd84ac54a89e416442508a6d4eaeff452c1. The commit was hanging correlationoptimizer14. --- .../scala/org/apache/spark/FutureAction.scala | 7 +- .../scala/org/apache/spark/Partitioner.scala | 29 ++------- .../apache/spark/rdd/AsyncRDDActions.scala | 64 ++++++++----------- 3 files changed, 34 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index c277c3a47d421..75ea535f2f57b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -208,7 +208,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R): R = { + resultFunc: => R) { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { @@ -223,10 +223,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. try { - Await.ready(job, Duration.Inf).value.get match { - case scala.util.Failure(e) => throw e - case scala.util.Success(v) => v - } + Await.ready(job, Duration.Inf) } catch { case e: InterruptedException => job.cancel() diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index d40b152d221c5..37053bb6f37ad 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -29,10 +29,6 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} -import org.apache.spark.SparkContext.rddToAsyncRDDActions -import scala.concurrent.Await -import scala.concurrent.duration.Duration - /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. @@ -117,12 +113,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - @volatile private var valRB: Array[K] = null - - private def rangeBounds: Array[K] = this.synchronized { - if (valRB != null) return valRB - - valRB = if (partitions <= 1) { + private var rangeBounds: Array[K] = { + if (partitions <= 1) { Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. @@ -160,8 +152,6 @@ class RangePartitioner[K : Ordering : ClassTag, V]( RangePartitioner.determineBounds(candidates, partitions) } } - - valRB } def numPartitions = rangeBounds.length + 1 @@ -232,8 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = this.synchronized { - if (valRB != null) return + private def readObject(in: ObjectInputStream) { val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() @@ -245,7 +234,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val ser = sfactory.newInstance() Utils.deserializeViaNestedStream(in, ser) { ds => implicit val classTag = ds.readObject[ClassTag[Array[K]]]() - valRB = ds.readObject[Array[K]]() + rangeBounds = ds.readObject[Array[K]]() } } } @@ -265,18 +254,12 @@ private[spark] object RangePartitioner { sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object - // use collectAsync here to run this job as a future, which is cancellable - val sketchFuture = rdd.mapPartitionsWithIndex { (idx, iter) => + val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => val seed = byteswap32(idx ^ (shift << 16)) val (sample, n) = SamplingUtils.reservoirSampleAndCount( iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) - }.collectAsync() - // We do need the future's value to continue any further - val sketched = Await.ready(sketchFuture, Duration.Inf).value.get match { - case scala.util.Success(v) => v.toArray - case scala.util.Failure(e) => throw e - } + }.collect() val numItems = sketched.map(_._2.toLong).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 7a68b3afa8158..b62f3fbdc4a15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag -import org.apache.spark.util.Utils import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} import org.apache.spark.annotation.Experimental @@ -39,30 +38,29 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for counting the number of elements in the RDD. */ def countAsync(): FutureAction[Long] = { - val f = new ComplexFutureAction[Long] - f.run { - val totalCount = new AtomicLong - f.runJob(self, - (iter: Iterator[T]) => Utils.getIteratorSize(iter), - Range(0, self.partitions.size), - (index: Int, data: Long) => totalCount.addAndGet(data), - totalCount.get()) - } + val totalCount = new AtomicLong + self.context.submitJob( + self, + (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }, + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) } /** * Returns a future for retrieving all elements of this RDD. */ def collectAsync(): FutureAction[Seq[T]] = { - val f = new ComplexFutureAction[Seq[T]] - f.run { - val results = new Array[Array[T]](self.partitions.size) - f.runJob(self, - (iter: Iterator[T]) => iter.toArray, - Range(0, self.partitions.size), - (index: Int, data: Array[T]) => results(index) = data, - results.flatten.toSeq) - } + val results = new Array[Array[T]](self.partitions.size) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + (index, data) => results(index) = data, results.flatten.toSeq) } /** @@ -106,34 +104,24 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } results.toSeq } + + f } /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(expr: T => Unit): FutureAction[Unit] = { - val f = new ComplexFutureAction[Unit] - val exprClean = self.context.clean(expr) - f.run { - f.runJob(self, - (iter: Iterator[T]) => iter.foreach(exprClean), - Range(0, self.partitions.size), - (index: Int, data: Unit) => Unit, - Unit) - } + def foreachAsync(f: T => Unit): FutureAction[Unit] = { + val cleanF = self.context.clean(f) + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + (index, data) => Unit, Unit) } /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(expr: Iterator[T] => Unit): FutureAction[Unit] = { - val f = new ComplexFutureAction[Unit] - f.run { - f.runJob(self, - expr, - Range(0, self.partitions.size), - (index: Int, data: Unit) => Unit, - Unit) - } + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + (index, data) => Unit, Unit) } } From 25164a89dd32eef58d9b6823ae259439f796e81a Mon Sep 17 00:00:00 2001 From: Jim Lim Date: Sun, 28 Sep 2014 19:04:24 -0700 Subject: [PATCH 117/315] SPARK-2761 refactor #maybeSpill into Spillable Moved `#maybeSpill` in ExternalSorter and EAOM into `Spillable`. Author: Jim Lim Closes #2416 from jimjh/SPARK-2761 and squashes the following commits: cf8be9a [Jim Lim] SPARK-2761 fix documentation, reorder code f94d522 [Jim Lim] SPARK-2761 refactor Spillable to simplify sig e75a24e [Jim Lim] SPARK-2761 use protected over protected[this] 7270e0d [Jim Lim] SPARK-2761 refactor #maybeSpill into Spillable --- .../collection/ExternalAppendOnlyMap.scala | 46 ++------ .../util/collection/ExternalSorter.scala | 68 +++-------- .../spark/util/collection/Spillable.scala | 111 ++++++++++++++++++ 3 files changed, 133 insertions(+), 92 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/Spillable.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a015c1d26a96..0c088da46aa5e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -66,23 +66,19 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) - extends Iterable[(K, C)] with Serializable with Logging { + extends Iterable[(K, C)] + with Serializable + with Logging + with Spillable[SizeTracker] { private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // Number of in-memory pairs inserted before tracking the map's shuffle memory usage - private val trackMemoryThreshold = 1000 - - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L + protected[this] var elementsRead = 0L /** * Size of object batches when reading/writing from serializers. @@ -95,11 +91,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - // How many times we have spilled so far - private var spillCount = 0 - // Number of bytes spilled in total - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 @@ -136,19 +128,8 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - currentMap.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = currentMap.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory) // Will also release memory back to ShuffleMemoryManager - } + if (maybeSpill(currentMap, currentMap.estimateSize())) { + currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) elementsRead += 1 @@ -171,11 +152,7 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - private def spill(mapSize: Long): Unit = { - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, @@ -231,18 +208,11 @@ class ExternalAppendOnlyMap[K, V, C]( } } - currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0L - elementsRead = 0 - _memoryBytesSpilled += mapSize } - def memoryBytesSpilled: Long = _memoryBytesSpilled def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 782b979e2e93d..0a152cb97ad9e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -79,14 +79,14 @@ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Option[Serializer] = None) extends Logging { + serializer: Option[Serializer] = None) + extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] { private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -115,22 +115,14 @@ private[spark] class ExternalSorter[K, V, C]( // Number of pairs read from input since last spill; note that we count them even if a value is // merged with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // What threshold of elementsRead we start estimating map size at. - private val trackMemoryThreshold = 1000 + protected[this] var elementsRead = 0L // Total spilling statistics - private var spillCount = 0 - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need // local aggregation and sorting, write numPartitions files directly and just concatenate them // at the end. This avoids doing serialization and deserialization twice to merge together the @@ -209,7 +201,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) - maybeSpill(usingMap = true) + maybeSpillCollection(usingMap = true) } } else { // Stick values into our buffer @@ -217,7 +209,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 val kv = records.next() buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - maybeSpill(usingMap = false) + maybeSpillCollection(usingMap = false) } } } @@ -227,61 +219,31 @@ private[spark] class ExternalSorter[K, V, C]( * * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def maybeSpill(usingMap: Boolean): Unit = { + private def maybeSpillCollection(usingMap: Boolean): Unit = { if (!spillingEnabled) { return } - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - - // TODO: factor this out of both here and ExternalAppendOnlyMap - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - collection.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = collection.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager + if (usingMap) { + if (maybeSpill(map, map.estimateSize())) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } + } else { + if (maybeSpill(buffer, buffer.estimateSize())) { + buffer = new SizeTrackingPairBuffer[(Int, K), C] } } } /** * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. - * - * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def spill(memorySize: Long, usingMap: Boolean): Unit = { - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - val memorySize = collection.estimateSize() - - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" - .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) - + override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { if (bypassMergeSort) { spillToPartitionFiles(collection) } else { spillToMergeableFile(collection) } - - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } - - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 - - _memoryBytesSpilled += memorySize } /** @@ -804,8 +766,6 @@ private[spark] class ExternalSorter[K, V, C]( } } - def memoryBytesSpilled: Long = _memoryBytesSpilled - def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala new file mode 100644 index 0000000000000..d7dccd4af8c6e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv + +/** + * Spills contents of an in-memory collection to disk when the memory threshold + * has been exceeded. + */ +private[spark] trait Spillable[C] { + + this: Logging => + + /** + * Spills the current in-memory collection to disk, and releases the memory. + * + * @param collection collection to spill to disk + */ + protected def spill(collection: C): Unit + + // Number of elements read from input since last spill + protected var elementsRead: Long + + // Memory manager that can be used to acquire/release memory + private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + // What threshold of elementsRead we start estimating collection size at + private[this] val trackMemoryThreshold = 1000 + + // How much of the shared memory pool this collection has claimed + private[this] var myMemoryThreshold = 0L + + // Number of bytes spilled in total + private[this] var _memoryBytesSpilled = 0L + + // Number of spills + private[this] var _spillCount = 0 + + /** + * Spills the current in-memory collection to disk if needed. Attempts to acquire more + * memory before spilling. + * + * @param collection collection to spill to disk + * @param currentMemory estimated size of the collection in bytes + * @return true if `collection` was spilled to disk; false otherwise + */ + protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + currentMemory >= myMemoryThreshold) { + // Claim up to double our current memory from the shuffle memory pool + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + _spillCount += 1 + logSpillage(currentMemory) + + spill(collection) + + // Keep track of spills, and release memory + _memoryBytesSpilled += currentMemory + releaseMemoryForThisThread() + return true + } + } + false + } + + /** + * @return number of bytes spilled in total + */ + def memoryBytesSpilled: Long = _memoryBytesSpilled + + /** + * Release our memory back to the shuffle pool so that other threads can grab it. + */ + private def releaseMemoryForThisThread(): Unit = { + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L + } + + /** + * Prints a standard log message detailing spillage. + * + * @param size number of bytes spilled + */ + @inline private def logSpillage(size: Long) { + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else "")) + } +} From f350cd307045c2c02e713225d8f1247f18ba123e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Sep 2014 20:32:54 -0700 Subject: [PATCH 118/315] [SPARK-3543] TaskContext remaining cleanup work. Author: Reynold Xin Closes #2560 from rxin/TaskContext and squashes the following commits: 9eff95a [Reynold Xin] [SPARK-3543] remaining cleanup work. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++- .../apache/spark/util/JavaTaskCompletionListenerImpl.java | 7 +++---- .../serializer/ProactiveClosureSerializationSuite.scala | 6 +----- .../apache/spark/sql/parquet/ParquetTableOperations.scala | 4 ++-- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 036dcc49664ef..21d0cc7b5cbaa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -194,7 +194,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 7f578bc5dac39..67833743f3a98 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -86,7 +86,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitionsWithContext((context, iter) => { + self.mapPartitions(iter => { + val context = TaskContext.get() new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) } else { diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index af34cdb03e4d1..0944bf8cd5c71 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,10 +30,9 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.runningLocally(); - context.taskMetrics(); + context.getStageId(); + context.getPartitionId(); + context.isRunningLocally(); context.addTaskCompletionListener(this); } } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index aad6599589420..d037e2c19a64d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -50,8 +50,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex "flatMap" -> xflatMap _, "filter" -> xfilter _, "mapPartitions" -> xmapPartitions _, - "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, - "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) { + "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) { val (name, xf) = transformation test(s"$name transformations throw proactive serialization exceptions") { @@ -78,8 +77,5 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) - - private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index d39e31a7fa195..ffb732347d30a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) From 0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 28 Sep 2014 21:44:50 -0700 Subject: [PATCH 119/315] [SPARK-1545] [mllib] Add Random Forests This PR adds RandomForest to MLlib. The implementation is basic, and future performance optimizations will be important. (Note: RFs = Random Forests.) # Overview ## RandomForest * trains multiple trees at once to reduce the number of passes over the data * allows feature subsets at each node * uses a queue of nodes instead of fixed groups for each level This implementation is based an implementation by manishamde and the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations). Thank you for your inputs! ## Testing Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets. Performance: This has been performance tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs). Results below. ### Regression tests for DecisionTree Summary: For training 1 tree, there are small regressions, especially from feature subsampling. In the table below, each row is a single (random) dataset. The 2 different sets of result columns are for 2 different RF implementations: * (numTrees): This is from an earlier commit, after implementing RandomForest to train multiple trees at once. It does not include any code for feature subsampling. * (feature subsets): This is from this current PR's code, after implementing feature subsampling. These tests were to identify regressions in DecisionTree, so they are training 1 tree with all of the features (i.e., no feature subsampling). These were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels). Speedup values < 1 indicate slowdowns from the old DecisionTree implementation. numInstances | numFeatures | runtime (sec) | speedup | runtime (sec) | speedup ---- | ---- | ---- | ---- | ---- | ---- | | (numTrees) | (numTrees) | (feature subsets) | (feature subsets) 20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471 20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857 20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638 20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556 200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673 200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571 200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131 200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005 2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667 2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911 2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136 2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337 20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494 20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826 20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599 ### Tests for forests I have run other tests with numTrees=10 and with sqrt(numFeatures), and those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees. # Details on specific classes ## Changes to DecisionTree * Main train() method is now in RandomForest. * findBestSplits() is no longer needed. (It split levels into groups, but we now use a queue of nodes.) * Many small changes to support RFs. (Note: These methods should be moved to RandomForest.scala in a later PR, but are in DecisionTree.scala to make code comparison easier.) ## RandomForest * Main train() method is from old DecisionTree. * selectNodesToSplit: Note that it selects nodes and feature subsets jointly to track memory usage. ## RandomForestModel * Stores an Array[DecisionTreeModel] * Prediction: * For classification, most common label. For regression, mean. * We could support other methods later. ## examples/.../DecisionTreeRunner * This now takes numTrees and featureSubsetStrategy, to support RFs. ## DTStatsAggregator * 2 types of functionality (w/ and w/o subsampling features): These require different indexing methods. (We could treat both as subsampling, but this is less efficient DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality. ## impurities * These now take instance weights. ## Node * Some vals changed to vars. * This is unfortunately a public API change (DeveloperApi). This could be avoided by creating a LearningNode struct, but would be awkward. ## RandomForestSuite Please let me know if there are missing tests! ## BaggedPoint This wraps TreePoint and holds bootstrap weights/counts. # Design decisions * BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on. * RandomForest public API: What options should be easily supported by the train* methods? Should ALL options be in the Java-friendly constructors? Should there be a constructor taking Strategy? * Feature subsampling options: What options should be supported? scikit-learn supports the same options, except for "onethird." One option would be to allow users to specific fractions ("0.1"): the current options could be supported, and any unrecognized values would be parsed as Doubles in [0,1]. * Splits and bins are computed before bootstrapping, so all trees use the same discretization. * One queue, instead of one queue per tree. CC: mengxr manishamde codedeft chouqin Please let me know if you have suggestions---thanks! Author: Joseph K. Bradley Author: qiping.lqp Author: chouqin Closes #2435 from jkbradley/rfs-new and squashes the following commits: c694174 [Joseph K. Bradley] Fixed typo cc59d78 [Joseph K. Bradley] fixed imports e25909f [Joseph K. Bradley] Simplified node group maps. Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression. Updated docs with references. ef7c293 [Joseph K. Bradley] Updates based on code review. Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite 593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features). Now it checks >= 1. a1a08df [Joseph K. Bradley] Removed old comments 866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds. ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues. d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest). Fixed bugs so that RandomForest now runs. 746d43c [Joseph K. Bradley] Implemented feature subsampling. Tested DecisionTree but not RandomForest. 6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new. Added RandomForestModel.toString b7ae594 [Joseph K. Bradley] Updated docs. Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes. 121c74e [Joseph K. Bradley] Basic random forests are implemented. Random features per node not yet implemented. Test suite not implemented. 325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new 4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1 0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training. 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code: efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree --- .../examples/mllib/DecisionTreeRunner.scala | 76 ++- .../spark/mllib/tree/DecisionTree.scala | 457 ++++++------------ .../spark/mllib/tree/RandomForest.scala | 451 +++++++++++++++++ .../spark/mllib/tree/impl/BaggedPoint.scala | 80 +++ .../mllib/tree/impl/DTStatsAggregator.scala | 219 +++++++-- .../tree/impl/DecisionTreeMetadata.scala | 47 +- .../spark/mllib/tree/impurity/Entropy.scala | 4 +- .../spark/mllib/tree/impurity/Gini.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 8 +- .../apache/spark/mllib/tree/model/Node.scala | 13 +- .../mllib/tree/model/RandomForestModel.scala | 105 ++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 210 ++++---- .../spark/mllib/tree/RandomForestSuite.scala | 245 ++++++++++ 14 files changed, 1410 insertions(+), 511 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4683e6eb966be..96fb068e9e126 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,16 +21,18 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** - * An example runner for decision tree. Run with + * An example runner for decision trees and random forests. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} @@ -57,6 +59,8 @@ object DecisionTreeRunner { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", fracTest: Double = 0.2) def main(args: Array[String]) { @@ -79,11 +83,20 @@ object DecisionTreeRunner { .action((x, c) => c.copy(maxBins = x)) opt[Int]("minInstancesPerNode") .text(s"min number of instances required at child nodes to create the parent split," + - s" default: ${defaultParams.minInstancesPerNode}") + s" default: ${defaultParams.minInstancesPerNode}") .action((x, c) => c.copy(minInstancesPerNode = x)) opt[Double]("minInfoGain") .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees (1 = decision tree, 2+ = random forest)," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"feature subset sampling strategy" + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s"default: ${defaultParams.featureSubsetStrategy}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -191,18 +204,35 @@ object DecisionTreeRunner { numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) - val model = DecisionTree.train(training, strategy) - - println(model) - - if (params.algo == Classification) { - val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy") - } - - if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + if (params.numTrees == 1) { + val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } + } else { + val randomSeed = Utils.random.nextInt() + if (params.algo == Classification) { + val model = RandomForest.trainClassifier(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val model = RandomForest.trainRegressor(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } } sc.stop() @@ -211,9 +241,7 @@ object DecisionTreeRunner { /** * Calculates the classifier accuracy. */ - private def accuracyScore( - model: DecisionTreeModel, - data: RDD[LabeledPoint]): Double = { + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count @@ -228,4 +256,14 @@ object DecisionTreeRunner { err * err }.mean() } + + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c7f2576c822b1..b7dc373ebd9cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,12 +18,14 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom @@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - - val timer = new TimeTracker() - - timer.start("total") - - timer.start("init") - - val retaggedInput = input.retag(classOf[LabeledPoint]) - val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) - logDebug("algo = " + strategy.algo) - logDebug("maxBins = " + metadata.maxBins) - - // Find the splits and the corresponding bins (interval between the splits) using a sample - // of the input data. - timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - timer.stop("findSplitsBins") - logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) - - // Bin feature values (TreePoint representation). - // Cache input RDD for speedup during multiple passes. - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - .persist(StorageLevel.MEMORY_AND_DISK) - - // depth of the decision tree - val maxDepth = strategy.maxDepth - require(maxDepth <= 30, - s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - - // Calculate level for single group construction - - // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L - logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - // TODO: Calculate memory usage more precisely. - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) - - logDebug("numElementsPerNode = " + numElementsPerNode) - val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) - logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) - // nodes at a level is 2^level. level is zero indexed. - val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) - logDebug("max level for single group = " + maxLevelForSingleGroup) - - timer.stop("init") - - /* - * The main idea here is to perform level-wise training of the decision tree nodes thus - * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is handled by a particular node at that level (or it reaches a leaf - * beforehand and is not used in later levels. - */ - - var topNode: Node = null // set on first iteration - var level = 0 - var break = false - while (level <= maxDepth && !break) { - logDebug("#####################################") - logDebug("level = " + level) - logDebug("#####################################") - - // Find best split for all nodes at a level. - timer.start("findBestSplits") - val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, - metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) - timer.stop("findBestSplits") - - if (level == 0) { - topNode = tmpTopNode - } - if (doneTraining) { - break = true - logDebug("done training") - } - - level += 1 - } - - logDebug("#####################################") - logDebug("Extracting tree model") - logDebug("#####################################") - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - new DecisionTreeModel(topNode, strategy.algo) + // Note: random seed will not be used since numTrees = 1. + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rfModel = rf.train(input) + rfModel.trees(0) } } @@ -352,58 +264,10 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - /** - * Returns an array of optimal splits for all nodes at a given level. Splits the task into - * multiple groups if the level-wise training task could lead to memory overflow. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. - */ - private[tree] def findBestSplits( - input: RDD[TreePoint], - metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, - splits: Array[Array[Split]], - bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): (Node, Boolean) = { - - // split into groups to avoid memory overflow during aggregation - if (level > maxLevelForSingleGroup) { - // When information for all nodes at a given level cannot be stored in memory, - // the nodes are divided into multiple groups at each level with the number of groups - // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, - // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = 1 << level - maxLevelForSingleGroup - logDebug("numGroups = " + numGroups) - // Iterate over each group of nodes at a level. - var groupIndex = 0 - var doneTraining = true - while (groupIndex < numGroups) { - val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, - topNode, splits, bins, timer, numGroups, groupIndex) - doneTraining = doneTraining && doneTrainingGroup - groupIndex += 1 - } - (topNode, doneTraining) // Not first iteration, so topNode was already set. - } else { - findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) - } - } - /** * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. * * @param node Node in tree from which to classify the given data point. * @param binnedFeatures Binned feature vector for data point. @@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging { * Otherwise, last node reachable in tree matching this example. * Note: This is the global node index, i.e., the index used in the tree. * This index is different from the index used during training a particular - * set of nodes in a (level, group). + * group of nodes on one call to [[findBestSplits()]]. */ private def predictNodeIndex( node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = { - if (node.isLeaf) { + if (node.isLeaf || node.split.isEmpty) { + // Node is either leaf, or has not yet been split. node.id } else { val featureIndex = node.split.get.feature @@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int, bins: Array[Array[Bin]], - unorderedFeatures: Set[Int]): Unit = { - // Iterate over all features. - val numFeatures = treePoint.binnedFeatures.size + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.size + } else { + // Use all features + agg.metadata.numFeatures + } val nodeOffset = agg.getNodeOffset(nodeIndex) - var featureIndex = 0 - while (featureIndex < numFeatures) { + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) // Update the left or right bin for each split. - val numSplits = agg.numSplits(featureIndex) + val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } splitIndex += 1 } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label, + instanceWeight) } - featureIndex += 1 + featureIndexIdx += 1 } } @@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). + * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int): Unit = { + nodeIndex: Int, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label val nodeOffset = agg.getNodeOffset(nodeIndex) - // Iterate over all features. - val numFeatures = agg.numFeatures - var featureIndex = 0 - while (featureIndex < numFeatures) { - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) - featureIndex += 1 + // Iterate over features. + if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.size) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } } } /** - * Returns an array of optimal splits for a group of nodes at a given level + * Given a group of nodes, this finds the best split for each node. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). + * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param numGroups total number of node groups at the current level. Default value is set to 1. - * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. */ - private def findBestSplitsPerGroup( - input: RDD[TreePoint], + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, + topNodes: Array[Node], + nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - timer: TimeTracker, - numGroups: Int = 1, - groupIndex: Int = 0): (Node, Boolean) = { + nodeQueue: mutable.Queue[(Int, Node)], + timer: TimeTracker = new TimeTracker): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. * - * *Level-wise training* - * We perform bin calculations for all nodes at the given level to avoid making multiple - * passes over the data. Thus, for a slightly increased computation and storage cost we save - * several iterations over the data especially at higher levels of the decision tree. + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. * * *Bin-wise computation* * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, - * is ordered (read ordering for categorical variables in the findSplitsBins method), - * we exploit this structure to calculate aggregates for bins and then use these aggregates + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. * * *Aggregation over partitions* @@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // Common calculations for multiple nested methods: - - // numNodes: Number of nodes in this (level of tree, group), - // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = Node.maxNodesInLevel(level) / numGroups + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.size).sum logDebug("numNodes = " + numNodes) - logDebug("numFeatures = " + metadata.numFeatures) logDebug("numClasses = " + metadata.numClasses) logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) - // shift when more than one group is used at deep tree level - val groupShift = numNodes * groupIndex - - // Used for treePointToNodeIndex to get an index for this (level, group). - // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. - // - groupShift corrects for groups in this level before the current group. - val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift - - /** - * Find the node index for the given example. - * Nodes are indexed from 0 at the start of this (level, group). - * If the example does not reach this level, returns a value < 0. - */ - def treePointToNodeIndex(treePoint: TreePoint): Int = { - if (level == 0) { - 0 - } else { - val globalNodeIndex = - predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) - globalNodeIndex - globalNodeIndexOffset - } - } - /** * Performs a sequential aggregation over a partition. * @@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). - * @param treePoint Data point being aggregated. + * @param baggedPoint Data point being aggregated. * @return agg */ def binSeqOp( agg: DTStatsAggregator, - treePoint: TreePoint): DTStatsAggregator = { - val nodeIndex = treePointToNodeIndex(treePoint) - // If the example does not reach this level, then nodeIndex < 0. - // If the example reaches this level but is handled in a different group, - // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). - if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, treePoint, nodeIndex) - } else { - mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) + baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, + bins, metadata.unorderedFeatures) + val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) + // If the example does not reach a node in this group, then nodeIndex = null. + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } } } agg @@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. timer.start("aggregation") val binAggregates: DTStatsAggregator = { - val initAgg = new DTStatsAggregator(metadata, numNodes) + val initAgg = if (metadata.subsamplingFeatures) { + new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo) + } else { + new DTStatsAggregatorFixedFeatures(metadata, numNodes) + } input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") - // Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes in the group timer.start("chooseSplits") - // On the first iteration, we need to get and return the newly created root node. - var newTopNode: Node = topNode - - // Iterate over all nodes at this level - var nodeIndex = 0 - var internalNodeCount = 0 - while (nodeIndex < numNodes) { - val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) - logDebug("best split = " + split) - - val globalNodeIndex = globalNodeIndexOffset + nodeIndex - // Extract info for this node at the current level. - val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) - val node = - new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - - if (!isLeaf) { - internalNodeCount += 1 - } - if (level == 0) { - newTopNode = node - } else { - // Set parent. - val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) - if (Node.isLeftChild(globalNodeIndex)) { - parentNode.leftNode = Some(node) - } else { - parentNode.rightNode = Some(node) + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. + val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) + assert(node.id == nodeIndex) + node.predict = predict.predict + node.isLeaf = isLeaf + node.stats = Some(stats) + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) + node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + logDebug("leftChildIndex = " + node.leftNode.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightNode.get.id + + ", impurity = " + stats.rightImpurity) } } - if (level < metadata.maxDepth) { - logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + - ", impurity = " + stats.leftImpurity) - logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + - ", impurity = " + stats.rightImpurity) - } - - nodeIndex += 1 } timer.stop("chooseSplits") - - val doneTraining = internalNodeCount == 0 - (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for all splits + * @return information gain and statistics for split */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging { * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a node + * @param rightImpurityCalculator right node aggregates for a split * @return predict value for current node */ private def calculatePredict( @@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. * @param binAggregates Bin statistics. - * @param nodeIndex Index for node to split in this (level, group). - * @return tuple for best split: (Split, information gain) + * @param nodeIndex Index into aggregates for node to split in this group. + * @return tuple for best split: (Split, information gain, prediction at node) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - level: Int, - metadata: DecisionTreeMetadata, - splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + + val metadata: DecisionTreeMetadata = binAggregates.metadata // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) var splitIndex = 0 while (splitIndex < numSplits) { binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) @@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) val numBins = metadata.numBins(featureIndex) /* Each bin is one category (feature value). @@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -903,18 +772,6 @@ object DecisionTree extends Serializable with Logging { (bestSplit, bestSplitStats, predict.get) } - /** - * Get the number of values to be stored per node in the bin aggregates. - */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { - val totalBins = metadata.numBins.map(_.toLong).sum - if (metadata.isClassification) { - metadata.numClasses * totalBins - } else { - 3 * totalBins - } - } - /** * Returns splits and bins for decision tree calculation. * Continuous and categorical features are handled differently. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala new file mode 100644 index 0000000000000..7fa7725e79e46 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impurity.Impurities +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * A class which implements a random forest learning algorithm for classification and regression. + * It supports both continuous and categorical features. + * + * The settings for featureSubsetStrategy are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @param strategy The configuration parameters for the random forest algorithm which specify + * the type of algorithm (classification, regression, etc.), feature type + * (continuous, categorical), depth of the tree, quantile calculation strategy, + * etc. + * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + */ +@Experimental +private class RandomForest ( + private val strategy: Strategy, + private val numTrees: Int, + featureSubsetStrategy: String, + private val seed: Int) + extends Serializable with Logging { + + strategy.assertValid() + require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + + /** + * Method to train a decision tree model over an RDD + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return RandomForestModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): RandomForestModel = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + logDebug("algo = " + strategy.algo) + logDebug("numTrees = " + numTrees) + logDebug("seed = " + seed) + logDebug("maxBins = " + metadata.maxBins) + logDebug("featureSubsetStrategy = " + featureSubsetStrategy) + logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplitsBins") + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) + timer.stop("findSplitsBins") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + val baggedInput = if (numTrees > 1) { + BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) + } else { + BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + }.persist(StorageLevel.MEMORY_AND_DISK) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val maxMemoryPerNode = { + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some(metadata.numBins.zipWithIndex.sortBy(- _._1) + .take(metadata.numFeaturesPerNode).map(_._2)) + } else { + None + } + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + } + require(maxMemoryPerNode <= maxMemoryUsage, + s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + + " which is too small for the given features." + + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + + timer.stop("init") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // FIFO queue of nodes to train: (treeIndex, node) + val nodeQueue = new mutable.Queue[(Int, Node)]() + + val rng = new scala.util.Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + + while (nodeQueue.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.size > 0, + s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + DecisionTree.findBestSplits(baggedInput, + metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + timer.stop("findBestSplits") + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) + RandomForestModel.build(trees) + } + +} + +object RandomForest extends Serializable with Logging { + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Classification, + s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Classification, impurityType, maxDepth, + numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) + trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Regression, + s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Regression, impurityType, maxDepth, + 0, maxBins, Sort, categoricalFeaturesInfo) + trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * List of supported feature subset sampling strategies. + */ + val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "sqrt", "log2", "onethird") + + private[tree] class NodeIndexInfo( + val nodeIndexInGroup: Int, + val featureSubset: Option[Array[Int]]) extends Serializable + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeQueue Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeQueue: mutable.Queue[(Int, Node)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + val (treeIndex, node) = nodeQueue.head + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) + Some(rng.shuffle(Range(0, metadata.numFeatures).toList) + .take(metadata.numFeaturesPerNode).toArray) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage) { + nodeQueue.dequeue() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + } + numNodesInGroup += 1 + memUsage += nodeMemUsage + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private[tree] def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala new file mode 100644 index 0000000000000..937c8a2ac5836 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * Internal representation of a datapoint which belongs to several subsamples of the same dataset, + * particularly for bagging (e.g., for random forests). + * + * This holds one instance, as well as an array of weights which represent the (weighted) + * number of times which this instance appears in each subsample. + * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that + * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. + * + * @param datum Data instance + * @param subsampleWeights Weight of this instance in each subsampled dataset. + * + * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted + * dataset support, update. (We store subsampleWeights as Double for this future extension.) + */ +private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) + extends Serializable + +private[tree] object BaggedPoint { + + /** + * Convert an input dataset into its BaggedPoint representation, + * choosing subsample counts for each instance. + * Each subsample has the same number of instances as the original dataset, + * and is created by subsampling with replacement. + * @param input Input dataset. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param seed Random seed. + * @return BaggedPoint dataset representation + */ + def convertToBaggedRDD[Datum]( + input: RDD[Datum], + numSubsamples: Int, + seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // TODO: Support different sampling rates, and sampling without replacement. + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1)) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + subsampleWeights(subsampleIndex) = poisson.nextInt() + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1.0))) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 61a94246711bf..d49df7a016375 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.tree.impl +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.impurity._ /** * DecisionTree statistics aggregator. * This holds a flat array of statistics for a set of (nodes, features, bins) * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. */ -private[tree] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - val numNodes: Int) extends Serializable { +private[tree] abstract class DTStatsAggregator( + val metadata: DecisionTreeMetadata) extends Serializable { /** * [[ImpurityAggregator]] instance specifying the impurity type. @@ -43,49 +44,21 @@ private[tree] class DTStatsAggregator( */ val statsSize: Int = impurityAggregator.statsSize - val numFeatures: Int = metadata.numFeatures - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - val numBins: Array[Int] = metadata.numBins - - /** - * Number of splits for the given feature. - */ - def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) - /** * Indicator for each feature of whether that feature is an unordered feature. * TODO: Is Array[Boolean] any faster? */ def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. - */ - private val nodeStride: Int = featureOffsets.last - /** * Total number of elements stored in this aggregator. */ - val allStatsSize: Int = numNodes * nodeStride + def allStatsSize: Int /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) - * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + * Get flat array of elements stored in this aggregator. */ - val allStats: Array[Double] = new Array[Double](allStatsSize) + protected def allStats: Array[Double] /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). @@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator( /** * Update the stats for a given (node, feature, bin) for ordered features, using the given label. */ - def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) + def update( + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) } /** * Pre-compute node offset for use with [[nodeUpdate]]. */ - def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + def getNodeOffset(nodeIndex: Int): Int /** * Faster version of [[update]]. * Update the stats for a given (node, feature, bin) for ordered features, using the given label. * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. */ - def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) - } + def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. * For ordered features only. */ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - nodeIndex * nodeStride + featureOffsets(featureIndex) - } + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. @@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator( def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") - val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) + s" but was called for ordered feature $featureIndex.") + val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex) + (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize) } /** @@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator( * (node, feature, left/right child) offset from * [[getLeftRightNodeFeatureOffsets]]. */ - def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { - impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + def nodeFeatureUpdate( + nodeFeatureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label, + instanceWeight) } /** @@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator( } this } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when not subsampling features. + * + * @param numNodes Number of nodes to collect statistics for. + */ +private[tree] class DTStatsAggregatorFixedFeatures( + metadata: DecisionTreeMetadata, + numNodes: Int) extends DTStatsAggregator(metadata) { + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: featureIndex --> offset + */ + private val featureOffsets: Array[Int] = { + metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last + override val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. + */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeIndex * nodeStride + featureOffsets(featureIndex) + } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when subsampling features. + * + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + */ +private[tree] class DTStatsAggregatorSubsampledFeatures( + metadata: DecisionTreeMetadata, + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) { + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: nodeIndex --> featureIndex --> offset + */ + private val featureOffsets: Array[Array[Int]] = { + val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum + val offsets = new Array[Array[Int]](numNodes) + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) => + nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) => + offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_)) + .scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + } + offsets + } + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + */ + protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _) + + override val allStatsSize: Int = nodeOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. + */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex) + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. + */ + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. + */ + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) + } } private[tree] object DTStatsAggregator extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index b6d49e5555b1a..212dce25236e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata( val quantileStrategy: QuantileStrategy, val maxDepth: Int, val minInstancesPerNode: Int, - val minInfoGain: Double) extends Serializable { + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + } private[tree] object DecisionTreeMetadata { @@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata { * This computes which categorical features will be ordered vs. unordered, * as well as the number of splits and bins for each feature. */ - def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size val numExamples = input.count() @@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata { } } + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + } + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain) + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) } /** + * Version of [[buildMetadata()]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** * Given the arity of a categorical feature (arity = number of categories), * return the number of bins for the feature if it is to be treated as an unordered feature. * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 1c8afc2d0f4bc..0e02345aa3774 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5cfdf345d163c..7c83cd48e16a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 5a047d6cb5480..60e2ab2bb829e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit /** * Get an [[ImpurityCalculator]] for a (node, feature, bin). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index e9ccecb1b8067..df9eafa5da16a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -75,10 +75,10 @@ private[tree] class VarianceAggregator() * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { - allStats(offset) += 1 - allStats(offset + 1) += label - allStats(offset + 2) += label * label + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + allStats(offset) += instanceWeight + allStats(offset + 1) += instanceWeight * label + allStats(offset + 2) += instanceWeight * label * label } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5f0095d23c7ed..56c3e25d9285f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - val predict: Double, - val isLeaf: Boolean, - val split: Option[Split], + var predict: Double, + var isLeaf: Boolean, + var split: Option[Split], var leftNode: Option[Node], var rightNode: Option[Node], - val stats: Option[InformationGainStats]) extends Serializable with Logging { + var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats @@ -167,6 +167,11 @@ class Node ( private[tree] object Node { + /** + * Return a node with the given node id (but nothing else set). + */ + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + /** * Return the index of the left child of this node. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala new file mode 100644 index 0000000000000..538c0e233202a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Random forest model for classification or regression. + * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make + * aggregate predictions. + * @param trees Trees which make up this forest. This cannot be empty. + * @param algo algorithm type -- classification or regression + */ +@Experimental +class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { + + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + + /** + * Predict values for a single data point. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Vector): Double = { + algo match { + case Classification => + val predictionToCount = new mutable.HashMap[Int, Int]() + trees.foreach { tree => + val prediction = tree.predict(features).toInt + predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 + } + predictionToCount.maxBy(_._2)._1 + case Regression => + trees.map(_.predict(features)).sum / trees.size + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = { + features.map(x => predict(x)) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Print full model. + */ + override def toString: String = { + val header = algo match { + case Classification => + s"RandomForestModel classifier with $numTrees trees\n" + case Regression => + s"RandomForestModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"RandomForestModel given unknown algo parameter: $algo.") + } + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + +} + +private[tree] object RandomForestModel { + + def build(trees: Array[DecisionTreeModel]): RandomForestModel = { + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + val algo: Algo = trees(0).algo + require(trees.forall(_.algo == algo), + "RandomForestModel cannot combine trees which have different output types" + + " (classification/regression).") + new RandomForestModel(trees, algo) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2b2e579b992f6..a48ed71a1c5fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.scalatest.FunSuite @@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext class DecisionTreeSuite extends FunSuite with LocalSparkContext { - def validateClassifier( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - - def validateRegressor( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") - } - test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - // 2^10 - 1 > 100, so categorical features will be ordered + // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 0) assert(bins(0).length === 0) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode: Node, doneTraining: Boolean) = - DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories === List(1.0)) @@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories.length === 1) @@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateRegressor(model, arr, 0.0) + DecisionTreeSuite.validateRegressor(model, arr, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNodeCopy1 = modelOneNode.topNode.deepCopy() - val rootNodeCopy2 = modelOneNode.topNode.deepCopy() + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) - // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy1, splits, bins, 10) - assert(rootNode.leftNode.nonEmpty) - assert(rootNode.rightNode.nonEmpty) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) val children1 = new Array[Node](2) - children1(0) = rootNode.leftNode.get - children1(1) = rootNode.rightNode.get - - // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second - // level tree construction. - val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy2, splits, bins, 0) - assert(rootNode2.leftNode.nonEmpty) - assert(rootNode2.rightNode.nonEmpty) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. + val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) val children2 = new Array[Node](2) children2(0) = rootNode2.leftNode.get children2(1) = rootNode2.rightNode.get @@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) @@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 0) @@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(strategy.isMulticlassClassification) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.6) + DecisionTreeSuite.validateClassifier(model, arr, 0.6) } test("split must satisfy min instances per node requirements") { @@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) assert(model.topNode.predict == 0.0) - val predicts = input.map(p => model.predict(p.features)).collect() + val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), numClassesForClassification = 2, minInstancesPerNode = 2) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get val gain = rootNode.stats.get @@ -757,12 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { object DecisionTreeSuite { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala new file mode 100644 index 0000000000000..30669fcd1c75b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impurity.{Gini, Variance} +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.util.StatCounter + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends FunSuite with LocalSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateClassifier(rf, arr, 0.9) + DecisionTreeSuite.validateClassifier(dt, arr, 0.9) + + // Make sure trees are the same. + assert(rfTree.toString == dt.toString) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateRegressor(rf, arr, 0.01) + DecisionTreeSuite.validateRegressor(dt, arr, 0.01) + + // Make sure trees are the same. + assert(rfTree.toString == dt.toString) + } + + test("Binary classification with continuous features: subsampling features") { + val numFeatures = 50 + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeQueue = new mutable.Queue[(Int, Node)]() + val topNodes: Array[Node] = new Array[Node](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1) + nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.size === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + } + +} + +object RandomForestSuite { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. sample from some distribution. + */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + + def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { + val numInstances = 1000 + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} From 1651cc117d73f0af6ec9f55b0c6c9b2bd565906c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sun, 28 Sep 2014 21:55:09 -0700 Subject: [PATCH 120/315] [EC2] Cleanup Python parens and disk dict Minor fixes: * Remove unnecessary parens (Python style) * Sort `disks_by_instance` dict and remove duplicate `t1.micro` key Author: Nicholas Chammas Closes #2571 from nchammas/ec2-polish and squashes the following commits: 9d203d5 [Nicholas Chammas] paren and dict cleanup --- ec2/spark_ec2.py | 60 ++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7f2cd7d94de39..5776d0b519309 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -508,7 +508,7 @@ def tag_instance(instance, name): break except: print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if (i == 5): + if i == 5: raise "Error - failed max attempts to add name tag" time.sleep(5) @@ -530,7 +530,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for instance in active: - if (instance.tags.get(u'Name') is None): + if instance.tags.get(u'Name') is None: tag_instance(instance, name) # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() @@ -545,7 +545,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): - print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) + print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) if master_nodes != [] or not die_on_error: return (master_nodes, slave_nodes) else: @@ -626,43 +626,43 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): def get_num_disks(instance_type): # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html # Updated 2014-6-20 + # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "t1.micro": 1, "c1.medium": 1, "c1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, + "c3.2xlarge": 2, + "c3.4xlarge": 2, + "c3.8xlarge": 2, + "c3.large": 2, + "c3.xlarge": 2, "cc1.4xlarge": 2, "cc2.8xlarge": 4, "cg1.4xlarge": 2, - "hs1.8xlarge": 24, "cr1.8xlarge": 2, + "g2.2xlarge": 1, "hi1.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "i2.xlarge": 1, + "hs1.8xlarge": 24, "i2.2xlarge": 2, "i2.4xlarge": 4, "i2.8xlarge": 8, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, + "i2.xlarge": 1, + "m1.large": 2, + "m1.medium": 1, + "m1.small": 1, + "m1.xlarge": 4, + "m2.2xlarge": 1, + "m2.4xlarge": 2, + "m2.xlarge": 1, + "m3.2xlarge": 2, + "m3.large": 1, + "m3.medium": 1, + "m3.xlarge": 2, "r3.2xlarge": 1, "r3.4xlarge": 1, "r3.8xlarge": 2, - "g2.2xlarge": 1, - "t1.micro": 0 + "r3.large": 1, + "r3.xlarge": 1, + "t1.micro": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -785,7 +785,7 @@ def ssh(host, opts, command): ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: - if (tries > 5): + if tries > 5: # If this was an ssh failure, provide the user with hints. if e.returncode == 255: raise UsageError( @@ -820,18 +820,18 @@ def ssh_read(host, opts, command): ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) -def ssh_write(host, opts, command, input): +def ssh_write(host, opts, command, arguments): tries = 0 while True: proc = subprocess.Popen( ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], stdin=subprocess.PIPE) - proc.stdin.write(input) + proc.stdin.write(arguments) proc.stdin.close() status = proc.wait() if status == 0: break - elif (tries > 5): + elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: print >> stderr, \ From 657bdff41a27568a981b3e342ad380fe92aa08a0 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Mon, 29 Sep 2014 01:13:15 -0700 Subject: [PATCH 121/315] [CORE] Bugfix: LogErr format in DAGScheduler.scala Author: Zhang, Liye Closes #2572 from liyezhang556520/DAGLogErr and squashes the following commits: 5be2491 [Zhang, Liye] Bugfix: LogErr format in DAGScheduler.scala --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 70c235dffff70..5a96f52a10cd4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1209,7 +1209,7 @@ class DAGScheduler( .format(job.jobId, stageId)) } else if (jobsForStage.get.size == 1) { if (!stageIdToStage.contains(stageId)) { - logError("Missing Stage for stage with id $stageId") + logError(s"Missing Stage for stage with id $stageId") } else { // This is the only job that uses this stage, so fail the stage if it is running. val stage = stageIdToStage(stageId) From aedd251c54fd130fe6e2f28d7587d39136e7ad1c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 29 Sep 2014 10:45:08 -0700 Subject: [PATCH 122/315] [EC2] Sort long, manually-inputted dictionaries Similar to the work done in #2571, this PR just sorts the remaining manually-inputted dicts in the EC2 script so they are easier to maintain. Author: Nicholas Chammas Closes #2578 from nchammas/ec2-dict-sort and squashes the following commits: f55c692 [Nicholas Chammas] sort long dictionaries --- ec2/spark_ec2.py | 69 ++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5776d0b519309..941dfb988b9fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -217,8 +217,15 @@ def is_active(instance): # Return correct versions of Spark and Shark, given the supplied Spark version def get_spark_shark_version(opts): spark_shark_map = { - "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" + "0.7.3": "0.7.1", + "0.8.0": "0.8.0", + "0.8.1": "0.8.1", + "0.9.0": "0.9.0", + "0.9.1": "0.9.1", + "1.0.0": "1.0.0", + "1.0.1": "1.0.1", + "1.0.2": "1.0.2", + "1.1.0": "1.1.0", } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: @@ -227,49 +234,49 @@ def get_spark_shark_version(opts): return (version, spark_shark_map[version]) -# Attempt to resolve an appropriate AMI given the architecture and -# region of the request. -# Information regarding Amazon Linux AMI instance type was update on 2014-6-20: -# http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Attempt to resolve an appropriate AMI given the architecture and region of the request. +# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Last Updated: 2014-06-20 +# For easy maintainability, please keep this manually-inputted dictionary sorted by key. def get_spark_ami(opts): instance_types = { - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "t1.micro": "pvm", "c1.medium": "pvm", "c1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", - "hs1.8xlarge": "pvm", - "hi1.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", "cr1.8xlarge": "hvm", - "i2.xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", "i2.2xlarge": "hvm", "i2.4xlarge": "hvm", "i2.8xlarge": "hvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", + "i2.xlarge": "hvm", + "m1.large": "pvm", + "m1.medium": "pvm", + "m1.small": "pvm", + "m1.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m2.xlarge": "pvm", + "m3.2xlarge": "hvm", + "m3.large": "hvm", + "m3.medium": "hvm", + "m3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", "r3.8xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "t1.micro": "pvm", + "t2.medium": "hvm", "t2.micro": "hvm", "t2.small": "hvm", - "t2.medium": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -624,8 +631,8 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): - # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Updated 2014-6-20 + # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html + # Last Updated: 2014-06-20 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, From 587a0cd7ed964ebfca2c97924c4f1e363f1fd3cb Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 29 Sep 2014 11:15:09 -0700 Subject: [PATCH 123/315] [MLlib] [SPARK-2885] DIMSUM: All-pairs similarity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # All-pairs similarity via DIMSUM Compute all pairs of similar vectors using brute force approach, and also DIMSUM sampling approach. Laying down some notation: we are looking for all pairs of similar columns in an m x n RowMatrix whose entries are denoted a_ij, with the i’th row denoted r_i and the j’th column denoted c_j. There is an oversampling parameter labeled ɣ that should be set to 4 log(n)/s to get provably correct results (with high probability), where s is the similarity threshold. The algorithm is stated with a Map and Reduce, with proofs of correctness and efficiency in published papers [1] [2]. The reducer is simply the summation reducer. The mapper is more interesting, and is also the heart of the scheme. As an exercise, you should try to see why in expectation, the map-reduce below outputs cosine similarities. ![dimsumv2](https://cloud.githubusercontent.com/assets/3220351/3807272/d1d9514e-1c62-11e4-9f12-3cfdb1d78b3a.png) [1] Bosagh-Zadeh, Reza and Carlsson, Gunnar (2013), Dimension Independent Matrix Square using MapReduce, arXiv:1304.1467 http://arxiv.org/abs/1304.1467 [2] Bosagh-Zadeh, Reza and Goel, Ashish (2012), Dimension Independent Similarity Computation, arXiv:1206.2082 http://arxiv.org/abs/1206.2082 # Testing Tests for all invocations included. Added L1 and L2 norm computation to MultivariateStatisticalSummary since it was needed. Added tests for both of them. Author: Reza Zadeh Author: Xiangrui Meng Closes #1778 from rezazadeh/dimsumv2 and squashes the following commits: 404c64c [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 4eb71c6 [Reza Zadeh] Add excludes for normL1 and normL2 ee8bd65 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 976ddd4 [Reza Zadeh] Broadcast colMags. Avoid div by zero. 3467cff [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 aea0247 [Reza Zadeh] Allow large thresholds to promote sparsity 9fe17c0 [Xiangrui Meng] organize imports 2196ba5 [Xiangrui Meng] Merge branch 'rezazadeh-dimsumv2' into dimsumv2 254ca08 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 f2947e4 [Xiangrui Meng] some optimization 3c4cf41 [Xiangrui Meng] Merge branch 'master' into rezazadeh-dimsumv2 0e4eda4 [Reza Zadeh] Use partition index for RNG 251bb9c [Reza Zadeh] Documentation 25e9d0d [Reza Zadeh] Line length for style fb296f6 [Reza Zadeh] renamed to normL1 and normL2 3764983 [Reza Zadeh] Documentation e9c6791 [Reza Zadeh] New interface and documentation 613f261 [Reza Zadeh] Column magnitude summary 75a0b51 [Reza Zadeh] Use Ints instead of Longs in the shuffle 0f12ade [Reza Zadeh] Style changes eb1dc20 [Reza Zadeh] Use Double.PositiveInfinity instead of Double.Max f56a882 [Reza Zadeh] Remove changes to MultivariateOnlineSummarizer dbc55ba [Reza Zadeh] Make colMagnitudes a method in RowMatrix 41e8ece [Reza Zadeh] style changes 139c8e1 [Reza Zadeh] Syntax changes 029aa9c [Reza Zadeh] javadoc and new test 75edb25 [Reza Zadeh] All tests passing! 05e59b8 [Reza Zadeh] Add test 502ce52 [Reza Zadeh] new interface 654c4fb [Reza Zadeh] default methods 3726ca9 [Reza Zadeh] Remove MatrixAlgebra 6bebabb [Reza Zadeh] remove changes to MatrixSuite 5b8cd7d [Reza Zadeh] Initial files --- .../mllib/linalg/distributed/RowMatrix.scala | 171 +++++++++++++++++- .../stat/MultivariateOnlineSummarizer.scala | 38 +++- .../stat/MultivariateStatisticalSummary.scala | 10 + .../linalg/distributed/RowMatrixSuite.scala | 37 ++++ project/MimaExcludes.scala | 9 +- 5 files changed, 259 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 4174f45d231c7..8380058cf9b41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,17 +19,21 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} -import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} +import scala.collection.mutable.ListBuffer + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, + svd => brzSvd} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel /** @@ -411,6 +415,165 @@ class RowMatrix( new RowMatrix(AB, nRows, B.numCols) } + /** + * Compute all cosine similarities between columns of this matrix using the brute-force + * approach of computing normalized dot products. + * + * @return An n x n sparse upper-triangular matrix of cosine similarities between + * columns of this matrix. + */ + def columnSimilarities(): CoordinateMatrix = { + columnSimilarities(0.0) + } + + /** + * Compute similarities between columns of this matrix using a sampling approach. + * + * The threshold parameter is a trade-off knob between estimate quality and computational cost. + * + * Setting a threshold of 0 guarantees deterministic correct results, but comes at exactly + * the same cost as the brute-force approach. Setting the threshold to positive values + * incurs strictly less computational cost than the brute-force approach, however the + * similarities computed will be estimates. + * + * The sampling guarantees relative-error correctness for those pairs of columns that have + * similarity greater than the given similarity threshold. + * + * To describe the guarantee, we set some notation: + * Let A be the smallest in magnitude non-zero element of this matrix. + * Let B be the largest in magnitude non-zero element of this matrix. + * Let L be the maximum number of non-zeros per row. + * + * For example, for {0,1} matrices: A=B=1. + * Another example, for the Netflix matrix: A=1, B=5 + * + * For those column pairs that are above the threshold, + * the computed similarity is correct to within 20% relative error with probability + * at least 1 - (0.981)^10/B^ + * + * The shuffle size is bounded by the *smaller* of the following two expressions: + * + * O(n log(n) L / (threshold * A)) + * O(m L^2^) + * + * The latter is the cost of the brute-force approach, so for non-zero thresholds, + * the cost is always cheaper than the brute-force approach. + * + * @param threshold Set to 0 for deterministic guaranteed correctness. + * Similarities above this threshold are estimated + * with the cost vs estimate quality trade-off described above. + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. + */ + def columnSimilarities(threshold: Double): CoordinateMatrix = { + require(threshold >= 0, s"Threshold cannot be negative: $threshold") + + if (threshold > 1) { + logWarning(s"Threshold is greater than 1: $threshold " + + "Computation will be more efficient with promoted sparsity, " + + " however there is no correctness guarantee.") + } + + val gamma = if (threshold < 1e-6) { + Double.PositiveInfinity + } else { + 10 * math.log(numCols()) / threshold + } + + columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) + } + + /** + * Find all similar columns using the DIMSUM sampling algorithm, described in two papers + * + * http://arxiv.org/abs/1206.2082 + * http://arxiv.org/abs/1304.1467 + * + * @param colMags A vector of column magnitudes + * @param gamma The oversampling parameter. For provable results, set to 10 * log(n) / s, + * where s is the smallest similarity score to be estimated, + * and n is the number of columns + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. + */ + private[mllib] def columnSimilaritiesDIMSUM( + colMags: Array[Double], + gamma: Double): CoordinateMatrix = { + require(gamma > 1.0, s"Oversampling should be greater than 1: $gamma") + require(colMags.size == this.numCols(), "Number of magnitudes didn't match column dimension") + val sg = math.sqrt(gamma) // sqrt(gamma) used many times + + // Don't divide by zero for those columns with zero magnitude + val colMagsCorrected = colMags.map(x => if (x == 0) 1.0 else x) + + val sc = rows.context + val pBV = sc.broadcast(colMagsCorrected.map(c => sg / c)) + val qBV = sc.broadcast(colMagsCorrected.map(c => math.min(sg, c))) + + val sims = rows.mapPartitionsWithIndex { (indx, iter) => + val p = pBV.value + val q = qBV.value + + val rand = new XORShiftRandom(indx) + val scaled = new Array[Double](p.size) + iter.flatMap { row => + val buf = new ListBuffer[((Int, Int), Double)]() + row match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + scaled(k) = sv.values(k) / q(sv.indices(k)) + k += 1 + } + k = 0 + while (k < nnz) { + val i = sv.indices(k) + val iVal = scaled(k) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var l = k + 1 + while (l < nnz) { + val j = sv.indices(l) + val jVal = scaled(l) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + l += 1 + } + } + k += 1 + } + case dv: DenseVector => + val n = dv.values.size + var i = 0 + while (i < n) { + scaled(i) = dv.values(i) / q(i) + i += 1 + } + i = 0 + while (i < n) { + val iVal = scaled(i) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var j = i + 1 + while (j < n) { + val jVal = scaled(j) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + j += 1 + } + } + i += 1 + } + } + buf + } + }.reduceByKey(_ + _).map { case ((i, j), sim) => + MatrixEntry(i.toLong, j.toLong, sim) + } + new CoordinateMatrix(sims, numCols(), numCols()) + } + private[mllib] override def toBreeze(): BDM[Double] = { val m = numRows().toInt val n = numCols().toInt diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7d845c44365dd..3025d4837cab4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -42,6 +42,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var n = 0 private var currMean: BDV[Double] = _ private var currM2n: BDV[Double] = _ + private var currM2: BDV[Double] = _ + private var currL1: BDV[Double] = _ private var totalCnt: Long = 0 private var nnz: BDV[Double] = _ private var currMax: BDV[Double] = _ @@ -60,6 +62,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) + currM2 = BDV.zeros[Double](n) + currL1 = BDV.zeros[Double](n) nnz = BDV.zeros[Double](n) currMax = BDV.fill(n)(Double.MinValue) currMin = BDV.fill(n)(Double.MaxValue) @@ -81,6 +85,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val tmpPrevMean = currMean(i) currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + currM2(i) += value * value + currL1(i) += math.abs(value) nnz(i) += 1.0 } @@ -97,7 +103,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. */ def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt @@ -114,6 +120,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / (nnz(i) + other.nnz(i)) } + // merge m2 together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2(i) += other.currM2(i) + } + // merge l1 together + if (nnz(i) + other.nnz(i) != 0.0) { + currL1(i) += other.currL1(i) + } + if (currMax(i) < other.currMax(i)) { currMax(i) = other.currMax(i) } @@ -127,6 +142,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.n = other.n this.currMean = other.currMean.copy this.currM2n = other.currM2n.copy + this.currM2 = other.currM2.copy + this.currL1 = other.currL1.copy this.totalCnt = other.totalCnt this.nnz = other.nnz.copy this.currMax = other.currMax.copy @@ -198,4 +215,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } Vectors.fromBreeze(currMin) } + + override def normL2: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + val realMagnitude = BDV.zeros[Double](n) + + var i = 0 + while (i < currM2.size) { + realMagnitude(i) = math.sqrt(currM2(i)) + i += 1 + } + + Vectors.fromBreeze(realMagnitude) + } + + override def normL1: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + Vectors.fromBreeze(currL1) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index f9eb343da2b82..6a364c93284af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -53,4 +53,14 @@ trait MultivariateStatisticalSummary { * Minimum value of each column. */ def min: Vector + + /** + * Euclidean magnitude of each column + */ + def normL2: Vector + + /** + * L1 norm of each column + */ + def normL1: Vector } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 1d3a3221365cc..63f3ed58c0d4d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -95,6 +95,40 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { } } + test("similar columns") { + val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)) + val expected = BDM( + (0.0, 54.0, 72.0), + (0.0, 0.0, 78.0), + (0.0, 0.0, 0.0)) + + for (i <- 0 until n; j <- 0 until n) { + expected(i, j) /= (colMags(i) * colMags(j)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities(0.11).toBreeze() + for (i <- 0 until n; j <- 0 until n) { + if (expected(i, j) > 0) { + val actual = expected(i, j) + val estimate = G(i, j) + assert(math.abs(actual - estimate) / actual < 0.2, + s"Similarities not close enough: $actual vs $estimate") + } + } + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities() + assert(closeToZero(G.toBreeze() - expected)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilaritiesDIMSUM(colMags.toArray, 150.0) + assert(closeToZero(G.toBreeze() - expected)) + } + } + test("svd of a full-rank matrix") { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { @@ -190,6 +224,9 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + assert(summary.normL2 === Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)), + "magnitude mismatch.") + assert(summary.normL1 === Vectors.dense(18.0, 12.0, 16.0), "L1 norm mismatch") } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3280e662fa0b1..1adfaa18c6202 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,7 +39,14 @@ object MimaExcludes { MimaBuild.excludeSparkPackage("graphx") ) ++ MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2") + ) case v if v.startsWith("1.1") => Seq( From dab1b0ae29a6d3017bdca23464f22a51d51eaae1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 29 Sep 2014 11:25:32 -0700 Subject: [PATCH 124/315] [SPARK-3032][Shuffle] Fix key comparison integer overflow introduced sorting exception Previous key comparison in `ExternalSorter` will get wrong sorting result or exception when key comparison overflows, details can be seen in [SPARK-3032](https://issues.apache.org/jira/browse/SPARK-3032). Here fix this and add a unit test to prove it. Author: jerryshao Closes #2514 from jerryshao/SPARK-3032 and squashes the following commits: 6f3c302 [jerryshao] Improve the unit test according to comments 01911e6 [jerryshao] Change the test to show the contract violate exception 83acb38 [jerryshao] Minor changes according to comments fa2a08f [jerryshao] Fix key comparison integer overflow introduced sorting exception --- .../util/collection/ExternalSorter.scala | 2 +- .../util/collection/ExternalSorterSuite.scala | 55 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 0a152cb97ad9e..644fa36818647 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -144,7 +144,7 @@ private[spark] class ExternalSorter[K, V, C]( override def compare(a: K, b: K): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() - h1 - h2 + if (h1 < h2) -1 else if (h1 == h2) 0 else 1 } }) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 706faed980f31..f26e40fbd4b36 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -24,6 +24,8 @@ import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ +import scala.util.Random + class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -707,4 +709,57 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) assertDidNotBypassMergeSort(sorter4) } + + test("sort without breaking sorting contracts") { + val conf = createSparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.01") + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // Using wrongOrdering to show integer overflow introduced exception. + val rand = new Random(100L) + val wrongOrdering = new Ordering[String] { + override def compare(a: String, b: String) = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + } + + val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + + val sorter1 = new ExternalSorter[String, String, String]( + None, None, Some(wrongOrdering), None) + val thrown = intercept[IllegalArgumentException] { + sorter1.insertAll(testData.iterator.map(i => (i, i))) + sorter1.iterator + } + + assert(thrown.getClass() === classOf[IllegalArgumentException]) + assert(thrown.getMessage().contains("Comparison method violates its general contract")) + sorter1.stop() + + // Using aggregation and external spill to make sure ExternalSorter using + // partitionKeyComparator. + def createCombiner(i: String) = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String) = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + sorter2.insertAll(testData.iterator.map(i => (i, i))) + + // To validate the hash ordering of key + var minKey = Int.MinValue + sorter2.iterator.foreach { case (k, v) => + val h = k.hashCode() + assert(h >= minKey) + minKey = h + } + + sorter2.stop() + } } From e43c72fe04d4fbf2a108b456d533e641b71b0a2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:38:24 -0700 Subject: [PATCH 125/315] Add more debug message for ManagedBuffer This is to help debug the error reported at http://apache-spark-user-list.1001560.n3.nabble.com/SQL-queries-fail-in-1-2-0-SNAPSHOT-td15327.html Author: Reynold Xin Closes #2580 from rxin/buffer-debug and squashes the following commits: 5814292 [Reynold Xin] Logging close() in case close() fails. 323dfec [Reynold Xin] Add more debug message. --- .../apache/spark/network/ManagedBuffer.scala | 43 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 14 ++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index e990c1da6730f..a4409181ec907 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,15 +17,17 @@ package org.apache.spark.network -import java.io.{FileInputStream, RandomAccessFile, File, InputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode +import scala.util.Try + import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} /** @@ -71,18 +73,47 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt try { channel = new RandomAccessFile(file, "r").getChannel channel.map(MapMode.READ_ONLY, offset, length) + } catch { + case e: IOException => + Try(channel.size).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } } finally { if (channel != null) { - channel.close() + Utils.tryLog(channel.close()) } } } override def inputStream(): InputStream = { - val is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) + var is: FileInputStream = null + try { + is = new FileInputStream(file) + is.skip(offset) + ByteStreams.limit(is, length) + } catch { + case e: IOException => + if (is != null) { + Utils.tryLog(is.close()) + } + Try(file.length).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } + case e: Throwable => + if (is != null) { + Utils.tryLog(is.close()) + } + throw e + } } + + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2755887feeeff..10d440828e323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1304,6 +1304,20 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block in a Try, logging any uncaught exceptions. */ + def tryLog[T](f: => T): Try[T] = { + try { + val res = f + scala.util.Success(res) + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + scala.util.Failure(t) + } + } + /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ def isFatalError(e: Throwable): Boolean = { e match { From 0bbe7faeffa17577ae8a33dfcd8c4c783db5c909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?baishuo=28=E7=99=BD=E7=A1=95=29?= Date: Mon, 29 Sep 2014 15:51:55 -0700 Subject: [PATCH 126/315] [SPARK-3007][SQL]Add Dynamic Partition support to Spark Sql hive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit a new PR base on new master. changes are the same as https://github.com/apache/spark/pull/1919 Author: baishuo(白硕) Author: baishuo Author: Cheng Lian Closes #2226 from baishuo/patch-3007 and squashes the following commits: e69ce88 [Cheng Lian] Adds tests to verify dynamic partitioning folder layout b20a3dc [Cheng Lian] Addresses @yhuai's comments 096bbbc [baishuo(白硕)] Merge pull request #1 from liancheng/refactor-dp 1093c20 [Cheng Lian] Adds more tests 5004542 [Cheng Lian] Minor refactoring fae9eff [Cheng Lian] Refactors InsertIntoHiveTable to a Command 528e84c [Cheng Lian] Fixes typo in test name, regenerated golden answer files c464b26 [Cheng Lian] Refactors dynamic partitioning support 5033928 [baishuo] pass check style 2201c75 [baishuo] use HiveConf.DEFAULTPARTITIONNAME to replace hive.exec.default.partition.name b47c9bf [baishuo] modify according micheal's advice c3ab36d [baishuo] modify for some bad indentation 7ce2d9f [baishuo] modify code to pass scala style checks 37c1c43 [baishuo] delete a empty else branch 66e33fc [baishuo] do a little modify 88d0110 [baishuo] update file after test a3961d9 [baishuo(白硕)] Update Cast.scala f7467d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala c1a59dd [baishuo(白硕)] Update Cast.scala 0e18496 [baishuo(白硕)] Update HiveQuerySuite.scala 60f70aa [baishuo(白硕)] Update InsertIntoHiveTable.scala 0a50db9 [baishuo(白硕)] Update HiveCompatibilitySuite.scala 491c7d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala a2374a8 [baishuo(白硕)] Update InsertIntoHiveTable.scala 701a814 [baishuo(白硕)] Update SparkHadoopWriter.scala dc24c41 [baishuo(白硕)] Update HiveQl.scala --- .../execution/HiveCompatibilitySuite.scala | 17 ++ .../org/apache/spark/SparkHadoopWriter.scala | 195 ---------------- .../org/apache/spark/sql/hive/HiveQl.scala | 5 - .../hive/execution/InsertIntoHiveTable.scala | 207 +++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 217 ++++++++++++++++++ ...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 + ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 .../sql/hive/execution/HiveQuerySuite.scala | 100 +++++++- 14 files changed, 443 insertions(+), 299 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 556c984ad392b..35e9c9939d4b7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", + "dynamic_partition_skip_default", + "infer_bucket_sort_dyn_part", + "load_dyn_part1", + "load_dyn_part2", + "load_dyn_part3", + "load_dyn_part4", + "load_dyn_part5", + "load_dyn_part6", + "load_dyn_part7", + "load_dyn_part8", + "load_dyn_part9", + "load_dyn_part10", + "load_dyn_part11", + "load_dyn_part12", + "load_dyn_part13", + "load_dyn_part14", + "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala deleted file mode 100644 index ab7862f4f9e06..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.Writable - -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveHadoopWriter( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private var format: HiveOutputFormat[AnyRef, Writable] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val extension = Utilities.getFileExtension( - conf.value, - fileSinkConf.getCompressed, - getOutputFormat()) - - val outputName = "part-" + numfmt.format(splitID) + extension - val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) - - getOutputCommitter().setupTask(getTaskContext()) - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - path, - null) - } - - def write(value: Writable) { - if (writer != null) { - writer.write(value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - } - format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - taskContext - } - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveHadoopWriter { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0aa6292c0184e..4e30e6e06fe21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -837,11 +837,6 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - if (partitionKeys.values.exists(p => p.isEmpty)) { - throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + - s"dynamic partitioning.") - } - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index a284a91a91e31..3d2ee010696f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import java.util.{HashMap => JHashMap} - import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector -import org.apache.hadoop.io.Writable +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} +import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} +import org.apache.spark.sql.hive._ +import org.apache.spark.{SerializableWritable, SparkException, TaskContext} /** * :: DeveloperApi :: @@ -51,7 +49,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode { + extends UnaryNode with Command { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -101,66 +99,74 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Writable], + rdd: RDD[Row], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + conf: SerializableWritable[JobConf], + writerContainer: SparkHiveWriterContainer) { + assert(valueClass != null, "Output value class not set") + conf.value.setOutputValueClass(valueClass) + + val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName + assert(outputFileFormatClassName != null, "Output format class not set") + conf.value.set("mapred.output.format.class", outputFileFormatClassName) + + val isCompressed = conf.value.getBoolean( + ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) + if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. - conf.set("mapred.output.compress", "true") + conf.value.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(conf.value.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.value.get("mapred.output.compression.type")) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + conf.value.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf.value, + SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + + // Note that this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) + val writer = writerContainer.getLocalFileWriter(row) + writer.write(serializer.serialize(outputData, standardOI)) } - writer.close() - writer.commit() + writerContainer.close() } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() } - override def execute() = result - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -168,50 +174,57 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + val numDynamicPartitions = partition.values.count(_.isEmpty) + val numStaticPartitions = partition.values.count(_.nonEmpty) + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" + } - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + // All partition column names in the format of "//..." + val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull - serializer.serialize(outputData, standardOI) + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } + + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. + val jobConfSer = new SerializableWritable(jobConf) + + val writerContainer = if (numDynamicPartitions > 0) { + val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + } else { + new SparkHiveWriterContainer(jobConf, fileSinkConf) + } + + saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) + val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -220,10 +233,6 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -231,14 +240,26 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + if (numDynamicPartitions > 0) { + db.loadDynamicPartitions( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir + ) + } else { + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } else { db.loadTable( outputPath, @@ -251,6 +272,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala new file mode 100644 index 0000000000000..a667188fa53bd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import org.apache.spark.sql.Row +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. + */ +private[hive] class SparkHiveWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + protected val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private lazy val committer = conf.value.getOutputCommitter + @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient private lazy val outputFormat = + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + + def driverSideSetup() { + setIDs(0, 0, 0) + setConfParams() + committer.setupJob(jobContext) + } + + def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { + setIDs(jobId, splitId, attemptId) + setConfParams() + committer.setupTask(taskContext) + initWriters() + } + + protected def getOutputName: String = { + val numberFormat = NumberFormat.getInstance() + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) + "part-" + numberFormat.format(splitID) + extension + } + + def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + commit() + } + + def commitJob() { + committer.commitJob(jobContext) + } + + protected def initWriters() { + // NOTE this method is executed at the executor side. + // For Hive tables without partitions or with only static partitions, only 1 writer is needed. + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), + Reporter.NULL) + } + + protected def commit() { + if (committer.needsTaskCommit(taskContext)) { + try { + committer.commitTask(taskContext) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + committer.abortTask(taskContext) + throw e + } + } else { + logInfo("No need to commit output of task: " + taID.value) + } + } + + // ********* Private Functions ********* + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveWriterContainer { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } +} + +private[spark] class SparkHiveDynamicPartitionWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + dynamicPartColNames: Array[String]) + extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + + private val defaultPartName = jobConf.get( + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + + @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ + + override protected def initWriters(): Unit = { + // NOTE: This method is executed at the executor side. + // Actual writers are created for each dynamic partition on the fly. + writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + } + + override def close(): Unit = { + writers.values.foreach(_.close(false)) + commit() + } + + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + val dynamicPartPath = dynamicPartColNames + .zip(row.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = String.valueOf(rawVal) + s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}" + } + .mkString + + def newWriter = { + val newFileSinkDesc = new FileSinkDesc( + fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getTableInfo, + fileSinkConf.getCompressed) + newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) + newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) + + val path = { + val outputPath = FileOutputFormat.getOutputPath(conf.value) + assert(outputPath != null, "Undefined job output-path") + val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) + new Path(workPath, getOutputName) + } + + HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + newFileSinkDesc, + path, + Reporter.NULL) + } + + writers.getOrElseUpdate(dynamicPartPath, newWriter) + } +} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2da8a6fac3d99..5d743a51b47c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -380,7 +383,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -568,6 +571,91 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) + createQueryTest("dynamic_partition", + """ + |DROP TABLE IF EXISTS dynamic_part_table; + |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); + | + |SET hive.exec.dynamic.partition.mode=nonstrict; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, NULL FROM src WHERE key=150; + | + |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, NULL FROM src WHERE key=150; + | + |DROP TABLE IF EXISTS dynamic_part_table; + """.stripMargin) + + test("Dynamic partition folder layout") { + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + val data = Map( + Seq("1", "1") -> 1, + Seq("1", "NULL") -> 2, + Seq("NULL", "1") -> 3, + Seq("NULL", "NULL") -> 4) + + data.foreach { case (parts, value) => + sql( + s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 + """.stripMargin) + + val partFolder = Seq("partcol1", "partcol2") + .zip(parts) + .map { case (k, v) => + if (v == "NULL") { + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + } else { + s"$k=$v" + } + } + .mkString("/") + + // Loads partition data to a temporary table to verify contents + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + } + } + + test("Partition spec validation") { + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") + + // Should throw when using strict dynamic partition mode without any static partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + // Should throw when a static partition appears after a dynamic partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + } + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -625,27 +713,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + collectResults(sql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). From 51229ff7f4d3517706a1cdc1a2943ede1c605089 Mon Sep 17 00:00:00 2001 From: yingjieMiao Date: Mon, 29 Sep 2014 18:01:27 -0700 Subject: [PATCH 127/315] [graphX] GraphOps: random pick vertex bug When `numVertices > 50`, probability is set to 0. This would cause infinite loop. Author: yingjieMiao Closes #2553 from yingjieMiao/graphx and squashes the following commits: 6adf3c8 [yingjieMiao] [graphX] GraphOps: random pick vertex bug --- graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 02afaa987d40d..d0dd45dba618e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -254,7 +254,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Picks a random vertex from the graph and returns its ID. */ def pickRandomVertex(): VertexId = { - val probability = 50 / graph.numVertices + val probability = 50.0 / graph.numVertices var found = false var retVal: VertexId = null.asInstanceOf[VertexId] while (!found) { From dc30e4504abcda1774f5f09a08bba73d29a2898b Mon Sep 17 00:00:00 2001 From: oded Date: Mon, 29 Sep 2014 18:05:53 -0700 Subject: [PATCH 128/315] Fixed the condition in StronglyConnectedComponents Issue: SPARK-3635 Author: oded Closes #2486 from odedz/master and squashes the following commits: dd7890a [oded] Fixed the condition in StronglyConnectedComponents Issue: SPARK-3635 --- .../apache/spark/graphx/lib/StronglyConnectedComponents.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 46da38eeb725a..8dd958033b338 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -75,7 +75,7 @@ object StronglyConnectedComponents { sccWorkGraph, Long.MaxValue, activeDirection = EdgeDirection.Out)( (vid, myScc, neighborScc) => (math.min(myScc._1, neighborScc), myScc._2), e => { - if (e.srcId < e.dstId) { + if (e.srcAttr._1 < e.dstAttr._1) { Iterator((e.dstId, e.srcAttr._1)) } else { Iterator() From 210404a56197ad347f1e621ed53ef01327fba2bd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 21:53:21 -0700 Subject: [PATCH 129/315] Minor cleanup of code. Author: Reynold Xin Closes #2581 from rxin/minor-cleanup and squashes the following commits: 736a91b [Reynold Xin] Minor cleanup of code. --- .../apache/spark/scheduler/JobLogger.scala | 17 +----- .../org/apache/spark/util/JsonProtocol.scala | 1 - .../scala/org/apache/spark/util/Utils.scala | 60 +++++++++---------- 3 files changed, 31 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index ceb434feb6ca1..54904bffdf10b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -20,15 +20,12 @@ package org.apache.spark.scheduler import java.io.{File, FileNotFoundException, IOException, PrintWriter} import java.text.SimpleDateFormat import java.util.{Date, Properties} -import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel +import org.apache.spark.executor.TaskMetrics /** * :: DeveloperApi :: @@ -62,24 +59,16 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent] createLogDir() - // The following 5 functions are used only in testing. - private[scheduler] def getLogDir = logDir - private[scheduler] def getJobIdToPrintWriter = jobIdToPrintWriter - private[scheduler] def getStageIdToJobId = stageIdToJobId - private[scheduler] def getJobIdToStageIds = jobIdToStageIds - private[scheduler] def getEventQueue = eventQueue - /** Create a folder for log files, the folder's name is the creation time of jobLogger */ protected def createLogDir() { val dir = new File(logDir + "/" + logDirName + "/") if (dir.exists()) { return } - if (dir.mkdirs() == false) { + if (!dir.mkdirs()) { // JobLogger should throw a exception rather than continue to construct this object. throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") } @@ -261,7 +250,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener protected def recordJobProperties(jobId: Int, properties: Properties) { if (properties != null) { val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobId, description, false) + jobLogInfo(jobId, description, withTime = false) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6a48f673c4e78..5b2e7d3a7edb9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -25,7 +25,6 @@ import scala.collection.Map import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleReadMetrics, diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 10d440828e323..dbe0cfa2b8ff9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -23,8 +23,6 @@ import java.nio.ByteBuffer import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} -import org.apache.log4j.PropertyConfigurator - import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer @@ -37,12 +35,12 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration +import org.apache.log4j.PropertyConfigurator import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -86,7 +84,7 @@ private[spark] object Utils extends Logging { ois.readObject.asInstanceOf[T] } - /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */ + /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */ def deserializeLongValue(bytes: Array[Byte]) : Long = { // Note: we assume that we are given a Long value encoded in network (big-endian) byte order var result = bytes(7) & 0xFFL @@ -153,7 +151,7 @@ private[spark] object Utils extends Logging { def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) /** - * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { if (bb.hasArray) { @@ -333,7 +331,7 @@ private[spark] object Utils extends Logging { val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) - val fileOverwrite = conf.getBoolean("spark.files.overwrite", false) + val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) @@ -355,7 +353,7 @@ private[spark] object Utils extends Logging { uc.connect() val in = uc.getInputStream() val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) + Utils.copyStream(in, out, closeStreams = true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { if (fileOverwrite) { targetFile.delete() @@ -402,7 +400,7 @@ private[spark] object Utils extends Logging { val fs = getHadoopFileSystem(uri, hadoopConf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) + Utils.copyStream(in, out, closeStreams = true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { if (fileOverwrite) { targetFile.delete() @@ -666,7 +664,7 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if ((file.isDirectory) && !isSymlink(file)) { + if (file.isDirectory() && !isSymlink(file)) { for (child <- listFilesSafely(file)) { deleteRecursively(child) } @@ -701,11 +699,7 @@ private[spark] object Utils extends Logging { new File(file.getParentFile().getCanonicalFile(), file.getName()) } - if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) { - return false - } else { - return true - } + !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()) } /** @@ -804,7 +798,7 @@ private[spark] object Utils extends Logging { .start() new Thread("read stdout for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { + for (line <- Source.fromInputStream(process.getInputStream).getLines()) { System.err.println(line) } } @@ -818,8 +812,10 @@ private[spark] object Utils extends Logging { /** * Execute a command and get its output, throwing an exception if it yields a code other than 0. */ - def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { + def executeAndGetOutput( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty): String = { val builder = new ProcessBuilder(command: _*) .directory(workingDir) val environment = builder.environment() @@ -829,7 +825,7 @@ private[spark] object Utils extends Logging { val process = builder.start() new Thread("read stderr for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { System.err.println(line) } } @@ -837,7 +833,7 @@ private[spark] object Utils extends Logging { val output = new StringBuffer val stdoutThread = new Thread("read stdout for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { + for (line <- Source.fromInputStream(process.getInputStream).getLines()) { output.append(line) } } @@ -846,8 +842,8 @@ private[spark] object Utils extends Logging { val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { - logError(s"Process $command exited with code $exitCode: ${output}") - throw new SparkException("Process " + command + " exited with code " + exitCode) + logError(s"Process $command exited with code $exitCode: $output") + throw new SparkException(s"Process $command exited with code $exitCode") } output.toString } @@ -860,6 +856,7 @@ private[spark] object Utils extends Logging { try { block } catch { + case e: ControlThrowable => throw e case t: Throwable => ExecutorUncaughtExceptionHandler.uncaughtException(t) } } @@ -884,13 +881,12 @@ private[spark] object Utils extends Logging { * @param skipClass Function that is used to exclude non-user-code classes. */ def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { - val trace = Thread.currentThread.getStackTrace() - .filterNot { ste:StackTraceElement => - // When running under some profilers, the current stack trace might contain some bogus - // frames. This is intended to ensure that we don't crash in these situations by - // ignoring any frames that we can't examine. - (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) - } + val trace = Thread.currentThread.getStackTrace().filterNot { ste: StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus + // frames. This is intended to ensure that we don't crash in these situations by + // ignoring any frames that we can't examine. + ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace") + } // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD @@ -924,7 +920,7 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt CallSite( - shortForm = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine", longForm = callStack.take(callStackDepth).mkString("\n")) } @@ -1027,7 +1023,7 @@ private[spark] object Utils extends Logging { false } - def isSpace(c: Char): Boolean = { + private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -1179,7 +1175,7 @@ private[spark] object Utils extends Logging { } import scala.sys.process._ (linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_! - ProcessLogger(line => (logInfo(line))) + ProcessLogger(line => logInfo(line)) } @@ -1260,7 +1256,7 @@ private[spark] object Utils extends Logging { val startTime = System.currentTimeMillis while (!terminated) { try { - process.exitValue + process.exitValue() terminated = true } catch { case e: IllegalThreadStateException => From 6b79bfb42580b6bd4c4cd99fb521534a94150693 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 22:56:22 -0700 Subject: [PATCH 130/315] [SPARK-3613] Record only average block size in MapStatus for large stages This changes the way we send MapStatus from executors back to driver for large stages (>2000 tasks). For large stages, we no longer send one byte per block. Instead, we just send the average block size. This makes large jobs (tens of thousands of tasks) much more reliable since the driver no longer sends huge amount of data. Author: Reynold Xin Closes #2470 from rxin/mapstatus and squashes the following commits: 822ff54 [Reynold Xin] Code review feedback. 3b86f56 [Reynold Xin] Added MimaExclude. f89d182 [Reynold Xin] Fixed a bug in MapStatus 6a0401c [Reynold Xin] [SPARK-3613] Record only average block size in MapStatus for large stages. --- .../org/apache/spark/MapOutputTracker.scala | 29 +---- .../apache/spark/scheduler/MapStatus.scala | 119 ++++++++++++++++-- .../shuffle/hash/HashShuffleWriter.scala | 8 +- .../shuffle/sort/SortShuffleWriter.scala | 3 +- .../apache/spark/MapOutputTrackerSuite.scala | 66 ++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/MapStatusSuite.scala | 92 ++++++++++++++ .../apache/spark/util/AkkaUtilsSuite.scala | 14 +-- project/MimaExcludes.scala | 5 +- 9 files changed, 240 insertions(+), 98 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index f92189b707fb5..4cb0bd4142435 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -349,7 +349,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will @@ -385,34 +384,8 @@ private[spark] object MapOutputTracker { throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { - (status.location, decompressSize(status.compressedSizes(reduceId))) + (status.location, status.getSizeForBlock(reduceId)) } } } - - /** - * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. - * We do this by encoding the log base 1.1 of the size as an integer, which can support - * sizes up to 35 GB with at most 10% error. - */ - def compressSize(size: Long): Byte = { - if (size == 0) { - 0 - } else if (size <= 1L) { - 1 - } else { - math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte - } - } - - /** - * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. - */ - def decompressSize(compressedSize: Byte): Long = { - if (compressedSize == 0) { - 0 - } else { - math.pow(LOG_BASE, compressedSize & 0xFF).toLong - } - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index d3f63ff92ac6f..e25096ea92d70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,22 +24,123 @@ import org.apache.spark.storage.BlockManagerId /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. - * The map output sizes are compressed using MapOutputTracker.compressSize. */ -private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) - extends Externalizable { +private[spark] sealed trait MapStatus { + /** Location where this task was run. */ + def location: BlockManagerId - def this() = this(null, null) // For deserialization only + /** Estimated size for the reduce block, in bytes. */ + def getSizeForBlock(reduceId: Int): Long +} + + +private[spark] object MapStatus { + + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + if (uncompressedSizes.length > 2000) { + new HighlyCompressedMapStatus(loc, uncompressedSizes) + } else { + new CompressedMapStatus(loc, uncompressedSizes) + } + } + + private[this] val LOG_BASE = 1.1 + + /** + * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. + * We do this by encoding the log base 1.1 of the size as an integer, which can support + * sizes up to 35 GB with at most 10% error. + */ + def compressSize(size: Long): Byte = { + if (size == 0) { + 0 + } else if (size <= 1L) { + 1 + } else { + math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte + } + } + + /** + * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. + */ + def decompressSize(compressedSize: Byte): Long = { + if (compressedSize == 0) { + 0 + } else { + math.pow(LOG_BASE, compressedSize & 0xFF).toLong + } + } +} + + +/** + * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is + * represented using a single byte. + * + * @param loc location where the task is being executed. + * @param compressedSizes size of the blocks, indexed by reduce partition id. + */ +private[spark] class CompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var compressedSizes: Array[Byte]) + extends MapStatus with Externalizable { + + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) + } - def writeExternal(out: ObjectOutput) { - location.writeExternal(out) + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = { + MapStatus.decompressSize(compressedSizes(reduceId)) + } + + override def writeExternal(out: ObjectOutput): Unit = { + loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } - def readExternal(in: ObjectInput) { - location = BlockManagerId(in) - compressedSizes = new Array[Byte](in.readInt()) + override def readExternal(in: ObjectInput): Unit = { + loc = BlockManagerId(in) + val len = in.readInt() + compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) } } + + +/** + * A [[MapStatus]] implementation that only stores the average size of the blocks. + * + * @param loc location where the task is being executed. + * @param avgSize average size of all the blocks + */ +private[spark] class HighlyCompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var avgSize: Long) + extends MapStatus with Externalizable { + + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.sum / uncompressedSizes.length) + } + + protected def this() = this(null, 0L) // For deserialization only + + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = avgSize + + override def writeExternal(out: ObjectOutput): Unit = { + loc.writeExternal(out) + out.writeLong(avgSize) + } + + override def readExternal(in: ObjectInput): Unit = { + loc = BlockManagerId(in) + avgSize = in.readLong() + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 4b9454d75abb7..746ed33b54c00 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -103,13 +103,11 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). - val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => writer.commitAndClose() - val size = writer.fileSegment().length - MapOutputTracker.compressSize(size) + writer.fileSegment().length } - - new MapStatus(blockManager.blockManagerId, compressedSizes) + MapStatus(blockManager.blockManagerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 89a78d6982ba0..927481b72cf4f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,8 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = new MapStatus(blockManager.blockManagerId, - partitionLengths.map(MapOutputTracker.compressSize)) + mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 5369169811f81..1fef79ad1001f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -23,32 +23,13 @@ import akka.actor._ import akka.testkit.TestActorRef import org.scalatest.FunSuite -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { private val conf = new SparkConf - test("compressSize") { - assert(MapOutputTracker.compressSize(0L) === 0) - assert(MapOutputTracker.compressSize(1L) === 1) - assert(MapOutputTracker.compressSize(2L) === 8) - assert(MapOutputTracker.compressSize(10L) === 25) - assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) - assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) - // This last size is bigger than we can encode in a byte, so check that we just return 255 - assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) - } - - test("decompressSize") { - assert(MapOutputTracker.decompressSize(0) === 0) - for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { - val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) - assert(size2 >= 0.99 * size && size2 <= 1.11 * size, - "size " + size + " decompressed to " + size2 + ", which is out of range") - } - } test("master start and stop") { val actorSystem = ActorSystem("test") @@ -65,14 +46,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(1000L, 10000L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(10000L, 1000L))) val statuses = tracker.getServerStatuses(10, 0) assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) @@ -84,11 +63,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + val compressedSize1000 = MapStatus.compressSize(1000L) + val compressedSize10000 = MapStatus.compressSize(10000L) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).nonEmpty) @@ -103,11 +82,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + val compressedSize1000 = MapStatus.compressSize(1000L) + val compressedSize10000 = MapStatus.compressSize(10000L) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simultaneous fetch failures @@ -142,10 +121,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getServerStatuses(10, 0).toSeq === @@ -173,8 +151,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Byte](10)(0))) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) masterActor.receive(GetMapOutputStatuses(10)) } @@ -194,8 +172,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // being sent. masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => - masterTracker.registerMapOutput(20, i, new MapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Byte](4000000)(0))) + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index aa73469b6acd8..a2e4f712db55b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -740,7 +740,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } private def makeMapStatus(host: String, reduces: Int): MapStatus = - new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2)) private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala new file mode 100644 index 0000000000000..79e04f046e4c4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer + + +class MapStatusSuite extends FunSuite { + + test("compressSize") { + assert(MapStatus.compressSize(0L) === 0) + assert(MapStatus.compressSize(1L) === 1) + assert(MapStatus.compressSize(2L) === 8) + assert(MapStatus.compressSize(10L) === 25) + assert((MapStatus.compressSize(1000000L) & 0xFF) === 145) + assert((MapStatus.compressSize(1000000000L) & 0xFF) === 218) + // This last size is bigger than we can encode in a byte, so check that we just return 255 + assert((MapStatus.compressSize(1000000000000000000L) & 0xFF) === 255) + } + + test("decompressSize") { + assert(MapStatus.decompressSize(0) === 0) + for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { + val size2 = MapStatus.decompressSize(MapStatus.compressSize(size)) + assert(size2 >= 0.99 * size && size2 <= 1.11 * size, + "size " + size + " decompressed to " + size2 + ", which is out of range") + } + } + + test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { + val sizes = Array.fill[Long](2001)(150L) + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + assert(status.getSizeForBlock(10) === 150L) + assert(status.getSizeForBlock(50) === 150L) + assert(status.getSizeForBlock(99) === 150L) + assert(status.getSizeForBlock(2000) === 150L) + } + + test(classOf[HighlyCompressedMapStatus].getName + ": estimated size is within 10%") { + val sizes = Array.tabulate[Long](50) { i => i.toLong } + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val ser = new JavaSerializer(new SparkConf) + val buf = ser.newInstance().serialize(status) + val status1 = ser.newInstance().deserialize[MapStatus](buf) + assert(status1.location == loc) + for (i <- 0 until sizes.length) { + // make sure the estimated size is within 10% of the input; note that we skip the very small + // sizes because the compression is very lossy there. + val estimate = status1.getSizeForBlock(i) + if (estimate > 100) { + assert(math.abs(estimate - sizes(i)) * 10 <= sizes(i), + s"incorrect estimated size $estimate, original was ${sizes(i)}") + } + } + } + + test(classOf[HighlyCompressedMapStatus].getName + ": estimated size should be the average size") { + val sizes = Array.tabulate[Long](3000) { i => i.toLong } + val avg = sizes.sum / sizes.length + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val ser = new JavaSerializer(new SparkConf) + val buf = ser.newInstance().serialize(status) + val status1 = ser.newInstance().deserialize[MapStatus](buf) + assert(status1.location == loc) + for (i <- 0 until 3000) { + val estimate = status1.getSizeForBlock(i) + assert(estimate === avg) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 76bf4cfd11267..7bca1711ae226 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -106,10 +106,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) @@ -157,10 +156,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1adfaa18c6202..4076ebc6fc8d5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2") + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), + // MapStatus should be private[spark] + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.MapStatus") ) case v if v.startsWith("1.1") => From de700d31778eb68807183cf32be8034abdc0120e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 23:17:53 -0700 Subject: [PATCH 131/315] [SPARK-3709] Executors don't always report broadcast block removal properly back to the driver The problem was that the 2nd argument in RemoveBroadcast is not tellMaster! It is "removeFromDriver". Basically when removeFromDriver is not true, we don't report broadcast block removal back to the driver, and then other executors mistakenly think that the executor would still have the block, and try to fetch from it. cc @tdas Author: Reynold Xin Closes #2588 from rxin/debug and squashes the following commits: 6dab2e3 [Reynold Xin] Don't log random messages. f430686 [Reynold Xin] Always report broadcast removal back to master. 2a13f70 [Reynold Xin] iii --- .../apache/spark/network/nio/NioBlockTransferService.scala | 2 +- .../org/apache/spark/storage/BlockManagerSlaveActor.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 59958ee894230..b389b9a2022c6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -200,6 +200,6 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val buffer = blockDataManager.getBlockData(blockId).orNull logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - buffer.nioByteBuffer() + if (buffer == null) null else buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 14ae2f38c5670..8462871e798a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -58,9 +58,9 @@ class BlockManagerSlaveActor( SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, tellMaster) => + case RemoveBroadcast(broadcastId, _) => doAsync[Int]("removing broadcast " + broadcastId, sender) { - blockManager.removeBroadcast(broadcastId, tellMaster) + blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => From b167a8c7e75d9e816784bd655bce1feb6c447210 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Sep 2014 23:36:10 -0700 Subject: [PATCH 132/315] [SPARK-3734] DriverRunner should not read SPARK_HOME from submitter's environment When using spark-submit in `cluster` mode to submit a job to a Spark Standalone cluster, if the JAVA_HOME environment variable was set on the submitting machine then DriverRunner would attempt to use the submitter's JAVA_HOME to launch the driver process (instead of the worker's JAVA_HOME), causing the driver to fail unless the submitter and worker had the same Java location. This commit fixes this by reading JAVA_HOME from sys.env instead of command.environment. Author: Josh Rosen Closes #2586 from JoshRosen/SPARK-3734 and squashes the following commits: e9513d9 [Josh Rosen] [SPARK-3734] DriverRunner should not read SPARK_HOME from submitter's environment. --- .../scala/org/apache/spark/deploy/worker/CommandUtils.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 12e98fd40d6c9..2e9be2a180c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils private[spark] object CommandUtils extends Logging { def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java") + val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows @@ -38,9 +38,6 @@ object CommandUtils extends Logging { command.arguments } - private def getEnv(key: String, command: Command): Option[String] = - command.environment.get(key).orElse(Option(System.getenv(key))) - /** * Attention: this must always be aligned with the environment variables in the run scripts and * the way the JAVA_OPTS are assembled there. From b64fcbd2dcec3418397328399c58f98d990a54f1 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 30 Sep 2014 09:43:46 -0700 Subject: [PATCH 133/315] Revert "[SPARK-3007][SQL]Add Dynamic Partition support to Spark Sql hive" This reverts commit 0bbe7faeffa17577ae8a33dfcd8c4c783db5c909. --- .../execution/HiveCompatibilitySuite.scala | 17 -- .../org/apache/spark/SparkHadoopWriter.scala | 195 ++++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 5 + .../hive/execution/InsertIntoHiveTable.scala | 207 ++++++++--------- .../spark/sql/hive/hiveWriterContainers.scala | 217 ------------------ ...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 - ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 .../sql/hive/execution/HiveQuerySuite.scala | 100 +------- 14 files changed, 299 insertions(+), 443 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf delete mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 35e9c9939d4b7..556c984ad392b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,23 +220,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", - "dynamic_partition_skip_default", - "infer_bucket_sort_dyn_part", - "load_dyn_part1", - "load_dyn_part2", - "load_dyn_part3", - "load_dyn_part4", - "load_dyn_part5", - "load_dyn_part6", - "load_dyn_part7", - "load_dyn_part8", - "load_dyn_part9", - "load_dyn_part10", - "load_dyn_part11", - "load_dyn_part12", - "load_dyn_part13", - "load_dyn_part14", - "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala new file mode 100644 index 0000000000000..ab7862f4f9e06 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.mapred._ +import org.apache.hadoop.io.Writable + +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. + */ +private[hive] class SparkHiveHadoopWriter( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + private val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private var format: HiveOutputFormat[AnyRef, Writable] = null + @transient private var committer: OutputCommitter = null + @transient private var jobContext: JobContext = null + @transient private var taskContext: TaskAttemptContext = null + + def preSetup() { + setIDs(0, 0, 0) + setConfParams() + + val jCtxt = getJobContext() + getOutputCommitter().setupJob(jCtxt) + } + + + def setup(jobid: Int, splitid: Int, attemptid: Int) { + setIDs(jobid, splitid, attemptid) + setConfParams() + } + + def open() { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + + val extension = Utilities.getFileExtension( + conf.value, + fileSinkConf.getCompressed, + getOutputFormat()) + + val outputName = "part-" + numfmt.format(splitID) + extension + val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) + + getOutputCommitter().setupTask(getTaskContext()) + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + path, + null) + } + + def write(value: Writable) { + if (writer != null) { + writer.write(value) + } else { + throw new IOException("Writer is null, open() has not been called") + } + } + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + } + + def commit() { + val taCtxt = getTaskContext() + val cmtr = getOutputCommitter() + if (cmtr.needsTaskCommit(taCtxt)) { + try { + cmtr.commitTask(taCtxt) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + cmtr.abortTask(taCtxt) + throw e + } + } else { + logWarning ("No need to commit output of task: " + taID.value) + } + } + + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + + // ********* Private Functions ********* + + private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { + if (format == null) { + format = conf.value.getOutputFormat() + .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + } + format + } + + private def getOutputCommitter(): OutputCommitter = { + if (committer == null) { + committer = conf.value.getOutputCommitter + } + committer + } + + private def getJobContext(): JobContext = { + if (jobContext == null) { + jobContext = newJobContext(conf.value, jID.value) + } + jobContext + } + + private def getTaskContext(): TaskAttemptContext = { + if (taskContext == null) { + taskContext = newTaskAttemptContext(conf.value, taID.value) + } + taskContext + } + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveHadoopWriter { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4e30e6e06fe21..0aa6292c0184e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -837,6 +837,11 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) + if (partitionKeys.values.exists(p => p.isEmpty)) { + throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + + s"dynamic partitioning.") + } + InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3d2ee010696f6..a284a91a91e31 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,25 +19,27 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ +import java.util.{HashMap => JHashMap} + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} -import org.apache.spark.sql.hive._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} +import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} /** * :: DeveloperApi :: @@ -49,7 +51,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode with Command { + extends UnaryNode { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -99,74 +101,66 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Row], + rdd: RDD[Writable], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { - assert(valueClass != null, "Output value class not set") - conf.value.setOutputValueClass(valueClass) - - val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName - assert(outputFileFormatClassName != null, "Output format class not set") - conf.value.set("mapred.output.format.class", outputFileFormatClassName) - - val isCompressed = conf.value.getBoolean( - ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) - + conf: JobConf, + isCompressed: Boolean) { + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + conf.setOutputValueClass(valueClass) + if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { + throw new SparkException("Output format class not set") + } + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? + // conf.setOutputFormat(outputFormatClass) + conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. - conf.value.set("mapred.output.compress", "true") + conf.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.value.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.value.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) } - conf.value.setOutputCommitter(classOf[FileOutputCommitter]) - + conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath( - conf.value, - SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) - log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + conf, + SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writeToFile _) - writerContainer.commitJob() - - // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) + val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) + writer.preSetup() + def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - iterator.foreach { row => - var i = 0 - while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.open() - val writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record) } - writerContainer.close() + writer.close() + writer.commit() } + + sc.sparkContext.runJob(rdd, writeToFile _) + writer.commitJob() } + override def execute() = result + /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -174,57 +168,50 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + private lazy val result: RDD[Row] = { + val childRdd = child.execute() + assert(childRdd != null) + // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) + val rdd = childRdd.mapPartitions { iter => + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] - val numDynamicPartitions = partition.values.count(_.isEmpty) - val numStaticPartitions = partition.values.count(_.nonEmpty) - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" - } - - // All partition column names in the format of "//..." - val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") - val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull - - // Validate partition spec if there exist any dynamic partitions - if (numDynamicPartitions > 0) { - // Report error if dynamic partitioning is not enabled - if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) - } - // Report error if dynamic partition strict mode is on but no static partition is found - if (numStaticPartitions == 0 && - sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) - } + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) + iter.map { row => + var i = 0 + while (i < row.length) { + // Casts Strings to HiveVarchars when necessary. + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - // Report error if any static partition appears after a dynamic partition - val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => - throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) + serializer.serialize(outputData, standardOI) } } + // ORC stores compression information in table properties. While, there are other formats + // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) - - val writerContainer = if (numDynamicPartitions > 0) { - val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) - new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) - } else { - new SparkHiveWriterContainer(jobConf, fileSinkConf) - } - - saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) - + saveAsHiveFile( + rdd, + outputClass, + fileSinkConf, + jobConf, + sc.hiveconf.getBoolean("hive.exec.compress.output", false)) + + // TODO: Handle dynamic partitioning. val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -233,6 +220,10 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" // Should not reach here right now. + } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -240,26 +231,14 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false - if (numDynamicPartitions > 0) { - db.loadDynamicPartitions( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - numDynamicPartitions, - holdDDLTime, - isSkewedStoreAsSubdir - ) - } else { - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) - } + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) } else { db.loadTable( outputPath, @@ -272,6 +251,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[Row] + sc.sparkContext.makeRDD(Nil, 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala deleted file mode 100644 index a667188fa53bd..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import scala.collection.mutable - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred._ - -import org.apache.spark.sql.Row -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveWriterContainer( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - protected val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private lazy val committer = conf.value.getOutputCommitter - @transient private lazy val jobContext = newJobContext(conf.value, jID.value) - @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) - @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - - def driverSideSetup() { - setIDs(0, 0, 0) - setConfParams() - committer.setupJob(jobContext) - } - - def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { - setIDs(jobId, splitId, attemptId) - setConfParams() - committer.setupTask(taskContext) - initWriters() - } - - protected def getOutputName: String = { - val numberFormat = NumberFormat.getInstance() - numberFormat.setMinimumIntegerDigits(5) - numberFormat.setGroupingUsed(false) - val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) - "part-" + numberFormat.format(splitID) + extension - } - - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - commit() - } - - def commitJob() { - committer.commitJob(jobContext) - } - - protected def initWriters() { - // NOTE this method is executed at the executor side. - // For Hive tables without partitions or with only static partitions, only 1 writer is needed. - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), - Reporter.NULL) - } - - protected def commit() { - if (committer.needsTaskCommit(taskContext)) { - try { - committer.commitTask(taskContext) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - committer.abortTask(taskContext) - throw e - } - } else { - logInfo("No need to commit output of task: " + taID.value) - } - } - - // ********* Private Functions ********* - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveWriterContainer { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } -} - -private[spark] class SparkHiveDynamicPartitionWriterContainer( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc, - dynamicPartColNames: Array[String]) - extends SparkHiveWriterContainer(jobConf, fileSinkConf) { - - private val defaultPartName = jobConf.get( - ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) - - @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ - - override protected def initWriters(): Unit = { - // NOTE: This method is executed at the executor side. - // Actual writers are created for each dynamic partition on the fly. - writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] - } - - override def close(): Unit = { - writers.values.foreach(_.close(false)) - commit() - } - - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { - val dynamicPartPath = dynamicPartColNames - .zip(row.takeRight(dynamicPartColNames.length)) - .map { case (col, rawVal) => - val string = String.valueOf(rawVal) - s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}" - } - .mkString - - def newWriter = { - val newFileSinkDesc = new FileSinkDesc( - fileSinkConf.getDirName + dynamicPartPath, - fileSinkConf.getTableInfo, - fileSinkConf.getCompressed) - newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) - newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) - - val path = { - val outputPath = FileOutputFormat.getOutputPath(conf.value) - assert(outputPath != null, "Undefined job output-path") - val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) - new Path(workPath, getOutputName) - } - - HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - newFileSinkDesc, - path, - Reporter.NULL) - } - - writers.getOrElseUpdate(dynamicPartPath, newWriter) - } -} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 deleted file mode 100644 index 573541ac9702d..0000000000000 --- a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 +++ /dev/null @@ -1 +0,0 @@ -0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5d743a51b47c5..2da8a6fac3d99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.hive.execution import scala.util.Try -import org.apache.hadoop.hive.conf.HiveConf.ConfVars - -import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -383,7 +380,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.contains("== Physical Plan ==") + explanation.exists(_ == "== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -571,91 +568,6 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) - createQueryTest("dynamic_partition", - """ - |DROP TABLE IF EXISTS dynamic_part_table; - |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); - | - |SET hive.exec.dynamic.partition.mode=nonstrict; - | - |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) - |SELECT 1, 1, 1 FROM src WHERE key=150; - | - |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) - |SELECT 1, NULL, 1 FROM src WHERE key=150; - | - |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) - |SELECT 1, 1, NULL FROM src WHERE key=150; - | - |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) - |SELECT 1, NULL, NULL FROM src WHERE key=150; - | - |DROP TABLE IF EXISTS dynamic_part_table; - """.stripMargin) - - test("Dynamic partition folder layout") { - sql("DROP TABLE IF EXISTS dynamic_part_table") - sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") - sql("SET hive.exec.dynamic.partition.mode=nonstrict") - - val data = Map( - Seq("1", "1") -> 1, - Seq("1", "NULL") -> 2, - Seq("NULL", "1") -> 3, - Seq("NULL", "NULL") -> 4) - - data.foreach { case (parts, value) => - sql( - s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) - |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 - """.stripMargin) - - val partFolder = Seq("partcol1", "partcol2") - .zip(parts) - .map { case (k, v) => - if (v == "NULL") { - s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" - } else { - s"$k=$v" - } - } - .mkString("/") - - // Loads partition data to a temporary table to verify contents - val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" - - sql("DROP TABLE IF EXISTS dp_verify") - sql("CREATE TABLE dp_verify(intcol INT)") - sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") - - assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) - } - } - - test("Partition spec validation") { - sql("DROP TABLE IF EXISTS dp_test") - sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") - sql("SET hive.exec.dynamic.partition.mode=strict") - - // Should throw when using strict dynamic partition mode without any static partition - intercept[SparkException] { - sql( - """INSERT INTO TABLE dp_test PARTITION(dp) - |SELECT key, value, key % 5 FROM src - """.stripMargin) - } - - sql("SET hive.exec.dynamic.partition.mode=nonstrict") - - // Should throw when a static partition appears after a dynamic partition - intercept[SparkException] { - sql( - """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) - |SELECT key, value, key % 5 FROM src - """.stripMargin) - } - } - test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -713,27 +625,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) + collectResults(hql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) + collectResults(hql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + collectResults(hql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) + collectResults(hql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) + collectResults(hql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). From 157e7d0f62eaf016a0c3749065ddcec170540a36 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 30 Sep 2014 09:46:58 -0700 Subject: [PATCH 134/315] HOTFIX: Ignore flaky tests in YARN --- .../scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 857a4447dd738..4b6635679f053 100644 --- a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -86,13 +86,13 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { super.afterAll() } - test("run Spark in yarn-client mode") { + ignore("run Spark in yarn-client mode") { var result = File.createTempFile("result", null, tempDir) YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) checkResult(result) } - test("run Spark in yarn-cluster mode") { + ignore("run Spark in yarn-cluster mode") { val main = YarnClusterDriver.getClass.getName().stripSuffix("$") var result = File.createTempFile("result", null, tempDir) From ab6dd80ba0f7e1042ea270d10400109a467fe40e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 30 Sep 2014 11:15:38 -0700 Subject: [PATCH 135/315] [SPARK-3356] [DOCS] Document when RDD elements' ordering within partitions is nondeterministic As suggested by mateiz , and because it came up on the mailing list again last week, this attempts to document that ordering of elements is not guaranteed across RDD evaluations in groupBy, zip, and partition-wise RDD methods. Suggestions welcome about the wording, or other methods that need a note. Author: Sean Owen Closes #2508 from srowen/SPARK-3356 and squashes the following commits: b7c96fd [Sean Owen] Undo change to programming guide ad4aeec [Sean Owen] Don't mention ordering in partition-wise methods, reword description of ordering for zip methods per review, and add similar note to programming guide, which mentions groupByKey (but not zip methods) fce943b [Sean Owen] Note that ordering of elements is not guaranteed across RDD evaluations in groupBy, zip, and partition-wise RDD methods --- .../apache/spark/rdd/PairRDDFunctions.scala | 9 +++++++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 20 ++++++++++++++++--- docs/programming-guide.md | 2 +- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 67833743f3a98..929ded58a3bd5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -420,6 +420,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. + * The ordering of elements within each group is not guaranteed, and may even differ + * each time the resulting RDD is evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] @@ -439,7 +441,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. + * resulting RDD with into `numPartitions` partitions. The ordering of elements within + * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] @@ -535,7 +538,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. + * resulting RDD with the existing partitioner/parallelism level. The ordering of elements + * within each group is not guaranteed, and may even differ each time the resulting RDD is + * evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ba712c9d7776f..ab9e97c8fe409 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -509,7 +509,8 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements - * mapping to that key. + * mapping to that key. The ordering of elements within each group is not guaranteed, and + * may even differ each time the resulting RDD is evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] @@ -520,7 +521,8 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. + * mapping to that key. The ordering of elements within each group is not guaranteed, and + * may even differ each time the resulting RDD is evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] @@ -531,7 +533,8 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements - * mapping to that key. + * mapping to that key. The ordering of elements within each group is not guaranteed, and + * may even differ each time the resulting RDD is evaluated. * * Note: This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] @@ -1028,8 +1031,14 @@ abstract class RDD[T: ClassTag]( * Zips this RDD with its element indices. The ordering is first based on the partition index * and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. + * * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. + * + * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * elements in a partition. The index assigned to each element is therefore not guaranteed, + * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee + * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithIndex(): RDD[(T, Long)] = new ZippedWithIndexRDD(this) @@ -1037,6 +1046,11 @@ abstract class RDD[T: ClassTag]( * Zips this RDD with generated unique Long ids. Items in the kth partition will get ids k, n+k, * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. + * + * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, + * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee + * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. */ def zipWithUniqueId(): RDD[(T, Long)] = { val n = this.partitions.size.toLong diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 510b47a2aaad1..1d61a3c555eaf 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -883,7 +883,7 @@ for details.
    + + + + + + + + + diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..b8cdbbe3cf2b6 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,6 +215,21 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8e7b00469e246..e9418320ff781 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,6 +20,7 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile +import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -30,7 +31,6 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel -from pyspark import rdd from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._temp_dir = \ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() + # profiling stats collected for each PythonRDD + self._profile_stats = [] + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) + def _add_profile(self, id, profileAcc): + if not self._profile_stats: + dump_path = self._conf.get("spark.python.profile.dump") + if dump_path: + atexit.register(self.dump_profiles, dump_path) + else: + atexit.register(self.show_profiles) + + self._profile_stats.append([id, profileAcc, False]) + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, acc, showed) in enumerate(self._profile_stats): + stats = acc.value + if not showed and stats: + print "=" * 60 + print "Profile of RDD" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + # mark it as showed + self._profile_stats[i][2] = True + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` + """ + if not os.path.exists(path): + os.makedirs(path) + for id, acc, _ in self._profile_stats: + stats = acc.value + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + self._profile_stats = [] + def _test(): import atexit diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 680140d72d03c..8ed89e2f9769f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,7 +15,6 @@ # limitations under the License. # -from base64 import standard_b64encode as b64enc import copy from collections import defaultdict from itertools import chain, ifilter, imap @@ -32,6 +31,7 @@ from random import Random from math import sqrt, log, isinf, isnan +from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2080,7 +2080,9 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - command = (self.func, self._prev_jrdd_deserializer, + enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" + profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None + command = (self.func, profileStats, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2102,6 +2104,10 @@ def _jrdd(self): self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() + + if enable_profile: + self._id = self._jrdd_val.id() + self.ctx._add_profile(self._id, profileStats) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f71d24c470dc9..d8bdf22355ec8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -960,7 +960,7 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, + command = (func, None, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29df754c6fd29..7e2bbc9cb617f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -632,6 +632,36 @@ def test_distinct(self): self.assertEquals(result.count(), 3) +class TestProfiler(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + + def test_profiler(self): + + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + profiles = self.sc._profile_stats + self.assertEqual(1, len(profiles)) + id, acc, _ = profiles[0] + stats = acc.value + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + self.sc.show_profiles() + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + class TestSQL(PySparkTestCase): def setUp(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c1f6e3e4a1f40..8257dddfee1c3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,8 @@ import time import socket import traceback +import cProfile +import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,10 +92,21 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, deserializer, serializer) = command + (func, stats, deserializer, serializer) = command init_time = time.time() - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + + def process(): + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) + + if stats: + p = cProfile.Profile() + p.runcall(process) + st = pstats.Stats(p) + st.stream = None # make it picklable + stats.add(st.strip_dirs()) + else: + process() except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) From eb43043f411b87b7b412ee31e858246bd93fdd04 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Oct 2014 00:29:14 -0700 Subject: [PATCH 142/315] [SPARK-3747] TaskResultGetter could incorrectly abort a stage if it cannot get result for a specific task Author: Reynold Xin Closes #2599 from rxin/SPARK-3747 and squashes the following commits: a74c04d [Reynold Xin] Added a line of comment explaining NonFatal 0e8d44c [Reynold Xin] [SPARK-3747] TaskResultGetter could incorrectly abort a stage if it cannot get result for a specific task --- .../org/apache/spark/scheduler/TaskResultGetter.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index df59f444b7a0e..3f345ceeaaf7a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import scala.util.control.NonFatal + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.serializer.SerializerInstance @@ -32,7 +34,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( - THREADS, "Result resolver thread") + THREADS, "task-result-getter") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { @@ -70,7 +72,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul case cnf: ClassNotFoundException => val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) - case ex: Exception => + // Matching NonFatal so we don't catch the ControlThrowable from the "return" above. + case NonFatal(ex) => logError("Exception while getting task result", ex) taskSetManager.abort("Exception while getting task result: %s".format(ex)) } From 7bf6cc9701cbb0f77fb85a412e387fb92274fca5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 1 Oct 2014 01:03:24 -0700 Subject: [PATCH 143/315] [SPARK-3751] [mllib] DecisionTree: example update + print options DecisionTreeRunner functionality additions: * Allow user to pass in a test dataset * Do not print full model if the model is too large. As part of this, modify DecisionTreeModel and RandomForestModel to allow printing less info. Proposed updates: * toString: prints model summary * toDebugString: prints full model (named after RDD.toDebugString) Similar update to Python API: * __repr__() now prints a model summary * toDebugString() now prints the full model CC: mengxr chouqin manishamde codedeft Small update (whomever can take a look). Thanks! Author: Joseph K. Bradley Closes #2604 from jkbradley/dtrunner-update and squashes the following commits: b2b3c60 [Joseph K. Bradley] re-added python sql doc test, temporarily removed before 07b1fae [Joseph K. Bradley] repr() now prints a model summary toDebugString() now prints the full model 1d0d93d [Joseph K. Bradley] Updated DT and RF to print less when toString is called. Added toDebugString for verbose printing. 22eac8c [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e007a95 [Joseph K. Bradley] Updated DecisionTreeRunner to accept a test dataset. --- .../examples/mllib/DecisionTreeRunner.scala | 99 ++++++++++++++----- .../mllib/tree/model/DecisionTreeModel.scala | 14 ++- .../mllib/tree/model/RandomForestModel.scala | 30 ++++-- python/pyspark/mllib/tree.py | 10 +- 4 files changed, 111 insertions(+), 42 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 96fb068e9e126..4adc91d2fbe65 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -52,6 +52,7 @@ object DecisionTreeRunner { case class Params( input: String = null, + testInput: String = "", dataFormat: String = "libsvm", algo: Algo = Classification, maxDepth: Int = 5, @@ -98,13 +99,18 @@ object DecisionTreeRunner { s"default: ${defaultParams.featureSubsetStrategy}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) opt[String]("") .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") .action((x, c) => c.copy(dataFormat = x)) arg[String]("") - .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)") + .text("input path to labeled examples") .required() .action((x, c) => c.copy(input = x)) checkConfig { params => @@ -141,7 +147,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() } // For classification, re-index classes if needed. - val (examples, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = params.algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -170,16 +176,40 @@ object DecisionTreeRunner { val frac = classCounts(c) / numExamples.toDouble println(s"$c\t$frac\t${classCounts(c)}") } - (examples, numClasses) + (examples, classIndexMap, numClasses) } case Regression => - (origExamples, 0) + (origExamples, null, 0) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - // Split into training, test. - val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + // Create training, test sets. + val splits = if (params.testInput != "") { + // Load testInput. + val origTestExamples = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + } + params.algo match { + case Classification => { + // classCounts: class --> # examples in class + val testExamples = { + if (classIndexMap.isEmpty) { + origTestExamples + } else { + origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features)) + } + } + Array(examples, testExamples) + } + case Regression => + Array(examples, origTestExamples) + } + } else { + // Split input into training, test. + examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + } val training = splits(0).cache() val test = splits(1).cache() val numTraining = training.count() @@ -206,47 +236,62 @@ object DecisionTreeRunner { minInfoGain = params.minInfoGain) if (params.numTrees == 1) { val model = DecisionTree.train(training, strategy) - println(model) + if (model.numNodes < 20) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } if (params.algo == Classification) { - val accuracy = + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision - println(s"Test accuracy = $accuracy") + println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + val trainMSE = meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") } } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) - println(model) - val accuracy = + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision - println(s"Test accuracy = $accuracy") + println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) - println(model) - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") } } sc.stop() } - /** - * Calculates the classifier accuracy. - */ - private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - val correctCount = data.filter(y => model.predict(y.features) == y.label).count() - val count = data.count() - correctCount.toDouble / count - } - /** * Calculates the mean squared error for regression. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 271b2c4ad813e..ec1d99ab26f9c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable } /** - * Print full model. + * Print a summary of the model. */ override def toString: String = algo match { case Classification => - s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2) + s"DecisionTreeModel classifier of depth $depth with $numNodes nodes" case Regression => - s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2) + s"DecisionTreeModel regressor of depth $depth with $numNodes nodes" case _ => throw new IllegalArgumentException( s"DecisionTreeModel given unknown algo parameter: $algo.") } + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + topNode.subtreeToString(2) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala index 538c0e233202a..4d66d6d81caa5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext def numTrees: Int = trees.size /** - * Print full model. + * Get total number of nodes, summed over all trees in the forest. */ - override def toString: String = { - val header = algo match { - case Classification => - s"RandomForestModel classifier with $numTrees trees\n" - case Regression => - s"RandomForestModel regressor with $numTrees trees\n" - case _ => throw new IllegalArgumentException( - s"RandomForestModel given unknown algo parameter: $algo.") - } + def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum + + /** + * Print a summary of the model. + */ + override def toString: String = algo match { + case Classification => + s"RandomForestModel classifier with $numTrees trees" + case Regression => + s"RandomForestModel regressor with $numTrees trees" + case _ => throw new IllegalArgumentException( + s"RandomForestModel given unknown algo parameter: $algo.") + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" header + trees.zipWithIndex.map { case (tree, treeIndex) => s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) }.fold("")(_ + _) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f59a818a6e74d..afdcdbdf3ae01 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -77,8 +77,13 @@ def depth(self): return self._java_model.depth() def __repr__(self): + """ Print summary of model. """ return self._java_model.toString() + def toDebugString(self): + """ Print full model. """ + return self._java_model.toDebugString() + class DecisionTree(object): @@ -135,7 +140,6 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree - >>> from pyspark.mllib.linalg import SparseVector >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), @@ -145,7 +149,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) >>> print model, # it already has newline - DecisionTreeModel classifier + DecisionTreeModel classifier of depth 1 with 3 nodes + >>> print model.toDebugString(), # it already has newline + DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) From 3888ee2f3875f7053f63f70190670247e5c77383 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Oct 2014 01:03:49 -0700 Subject: [PATCH 144/315] [SPARK-3748] Log thread name in unit test logs Thread names are useful for correlating failures. Author: Reynold Xin Closes #2600 from rxin/log4j and squashes the following commits: 83ffe88 [Reynold Xin] [SPARK-3748] Log thread name in unit test logs --- bagel/src/test/resources/log4j.properties | 2 +- core/src/test/resources/log4j.properties | 2 +- external/flume/src/test/resources/log4j.properties | 2 +- external/kafka/src/test/resources/log4j.properties | 2 +- external/mqtt/src/test/resources/log4j.properties | 2 +- external/twitter/src/test/resources/log4j.properties | 2 +- external/zeromq/src/test/resources/log4j.properties | 2 +- extras/java8-tests/src/test/resources/log4j.properties | 2 +- extras/kinesis-asl/src/test/resources/log4j.properties | 2 +- graphx/src/test/resources/log4j.properties | 2 +- mllib/src/test/resources/log4j.properties | 2 +- repl/src/test/resources/log4j.properties | 2 +- sql/core/src/test/resources/log4j.properties | 2 +- sql/hive/src/test/resources/log4j.properties | 2 +- streaming/src/test/resources/log4j.properties | 2 +- yarn/stable/src/test/resources/log4j.properties | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties index 30b4baa4d714a..789869f72e3b0 100644 --- a/bagel/src/test/resources/log4j.properties +++ b/bagel/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index 26b73a1b39744..9dd05f17f012b 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/external/mqtt/src/test/resources/log4j.properties +++ b/external/mqtt/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/external/twitter/src/test/resources/log4j.properties +++ b/external/twitter/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/external/zeromq/src/test/resources/log4j.properties +++ b/external/zeromq/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties index 180beaa8cc5a7..bb0ab319a0080 100644 --- a/extras/java8-tests/src/test/resources/log4j.properties +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties index e01e049595475..d9d08f68687d3 100644 --- a/extras/kinesis-asl/src/test/resources/log4j.properties +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -20,7 +20,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index 26b73a1b39744..9dd05f17f012b 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index ddfc4ac6b23ed..a469badf603c6 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index 9c4896e49698c..52098993f5c3c 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index c7e0ff1cf6494..fbed0a782dd3e 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -30,7 +30,7 @@ log4j.appender.FA=org.apache.log4j.FileAppender log4j.appender.FA.append=false log4j.appender.FA.file=target/unit-tests.log log4j.appender.FA.layout=org.apache.log4j.PatternLayout -log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n # Set the logger level of File Appender to WARN log4j.appender.FA.Threshold = INFO diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties index c07d8fedf1993..9fdb526d945e0 100644 --- a/sql/hive/src/test/resources/log4j.properties +++ b/sql/hive/src/test/resources/log4j.properties @@ -30,7 +30,7 @@ log4j.appender.FA=org.apache.log4j.FileAppender log4j.appender.FA.append=false log4j.appender.FA.file=target/unit-tests.log log4j.appender.FA.layout=org.apache.log4j.PatternLayout -log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n # Set the logger level of File Appender to WARN log4j.appender.FA.Threshold = INFO diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 45d2ec676df66..4411d6e20c52a 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN diff --git a/yarn/stable/src/test/resources/log4j.properties b/yarn/stable/src/test/resources/log4j.properties index 26b73a1b39744..9dd05f17f012b 100644 --- a/yarn/stable/src/test/resources/log4j.properties +++ b/yarn/stable/src/test/resources/log4j.properties @@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN From 0bfd3afb00936b0f46ba613be0982e38bc7032b5 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Wed, 1 Oct 2014 08:55:04 -0700 Subject: [PATCH 145/315] [SPARK-3757] mvn clean doesn't delete some files Added directory to be deleted into maven-clean-plugin in pom.xml. Author: Masayoshi TSUZUKI Closes #2613 from tsudukim/feature/SPARK-3757 and squashes the following commits: 8804bfc [Masayoshi TSUZUKI] Modified indent. 67c7171 [Masayoshi TSUZUKI] [SPARK-3757] mvn clean doesn't delete some files --- core/pom.xml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index e012c5e673b74..a5a178079bc57 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -322,6 +322,17 @@ + + maven-clean-plugin + + + + ${basedir}/../python/build + + + true + + org.apache.maven.plugins maven-shade-plugin From abf588f47a26d0066f0b75d52b200a87bb085064 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Oct 2014 11:21:34 -0700 Subject: [PATCH 146/315] [SPARK-3749] [PySpark] fix bugs in broadcast large closure of RDD 1. broadcast is triggle unexpected 2. fd is leaked in JVM (also leak in parallelize()) 3. broadcast is not unpersisted in JVM after RDD is not be used any more. cc JoshRosen , sorry for these stupid bugs. Author: Davies Liu Closes #2603 from davies/fix_broadcast and squashes the following commits: 080a743 [Davies Liu] fix bugs in broadcast large closure of RDD --- .../apache/spark/api/python/PythonRDD.scala | 34 ++++++++++++------- python/pyspark/rdd.py | 12 +++++-- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 8 +++-- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f9ff4ea6ca157..924141475383d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging { def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] try { - while (true) { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - objs.append(obj) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} } - } catch { - case eof: EOFException => {} + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } finally { + file.close() } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) + try { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } finally { + file.close() + } } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ed89e2f9769f..dc6497772e502 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2073,6 +2073,12 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None + self._broadcast = None + + def __del__(self): + if self._broadcast: + self._broadcast.unpersist() + self._broadcast = None @property def _jrdd(self): @@ -2087,9 +2093,9 @@ def _jrdd(self): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) - if pickled_command > (1 << 20): # 1M - broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) + if len(pickled_command) > (1 << 20): # 1M + self._broadcast = self.ctx.broadcast(pickled_command) + pickled_command = ser.dumps(self._broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d8bdf22355ec8..974b5e287bc00 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -965,7 +965,7 @@ def registerFunction(self, name, f, returnType=StringType()): BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) - if pickled_command > (1 << 20): # 1M + if len(pickled_command) > (1 << 20): # 1M broadcast = self._sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7e2bbc9cb617f..6fb6bc998c752 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -467,8 +467,12 @@ def test_large_broadcast(self): def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] - m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum() - self.assertEquals(N, m) + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) + self.assertEquals(N, rdd.first()) + self.assertTrue(rdd._broadcast is not None) + rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) + self.assertEqual(1, rdd.first()) + self.assertTrue(rdd._broadcast is None) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) From dcb2f73f1cf1f6efd5175267e135ad6cf4bf6e3d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 1 Oct 2014 11:28:22 -0700 Subject: [PATCH 147/315] SPARK-2626 [DOCS] Stop SparkContext in all examples Call SparkContext.stop() in all examples (and touch up minor nearby code style issues while at it) Author: Sean Owen Closes #2575 from srowen/SPARK-2626 and squashes the following commits: 5b2baae [Sean Owen] Call SparkContext.stop() in all examples (and touch up minor nearby code style issues while at it) --- .../main/java/org/apache/spark/examples/JavaSparkPi.java | 3 ++- .../java/org/apache/spark/examples/sql/JavaSparkSQL.java | 9 ++++++++- examples/src/main/python/avro_inputformat.py | 2 ++ examples/src/main/python/parquet_inputformat.py | 2 ++ .../org/apache/spark/examples/CassandraCQLTest.scala | 2 ++ .../scala/org/apache/spark/examples/CassandraTest.scala | 2 ++ .../scala/org/apache/spark/examples/GroupByTest.scala | 6 +++--- .../main/scala/org/apache/spark/examples/LogQuery.scala | 2 ++ .../apache/spark/examples/bagel/WikipediaPageRank.scala | 9 +++++---- .../org/apache/spark/examples/sql/RDDRelation.scala | 4 +++- .../apache/spark/examples/sql/hive/HiveFromSpark.scala | 4 +++- 11 files changed, 34 insertions(+), 11 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 11157d7573fae..0f07cb4098325 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -31,7 +31,6 @@ * Usage: JavaSparkPi [slices] */ public final class JavaSparkPi { - public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi"); @@ -61,5 +60,7 @@ public Integer call(Integer integer, Integer integer2) { }); System.out.println("Pi is roughly " + 4.0 * count / n); + + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 898297dc658ba..01c77bd44337e 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -61,7 +61,8 @@ public static void main(String[] args) throws Exception { // Load a text file and convert each line to a Java Bean. JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( new Function() { - public Person call(String line) throws Exception { + @Override + public Person call(String line) { String[] parts = line.split(","); Person person = new Person(); @@ -82,6 +83,7 @@ public Person call(String line) throws Exception { // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.map(new Function() { + @Override public String call(Row row) { return "Name: " + row.getString(0); } @@ -104,6 +106,7 @@ public String call(Row row) { JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.map(new Function() { + @Override public String call(Row row) { return "Name: " + row.getString(0); } @@ -136,6 +139,7 @@ public String call(Row row) { // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.map(new Function() { + @Override public String call(Row row) { return "Name: " + row.getString(0); } }).collect(); for (String name: teenagerNames) { @@ -162,6 +166,7 @@ public String call(Row row) { JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.map(new Function() { + @Override public String call(Row row) { return "Name: " + row.getString(0) + ", City: " + row.getString(1); } @@ -169,5 +174,7 @@ public String call(Row row) { for (String name: nameAndCity) { System.out.println(name); } + + ctx.stop(); } } diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index cfda8d8327aa3..4626bbb7e3b02 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -78,3 +78,5 @@ output = avro_rdd.map(lambda x: x[0]).collect() for k in output: print k + + sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index c9b08f878a1e6..fa4c20ab20281 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -57,3 +57,5 @@ output = parquet_rdd.map(lambda x: x[1]).collect() for k in output: print k + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 71f53af68f4d3..11d5c92c5952d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -136,5 +136,7 @@ object CassandraCQLTest { classOf[CqlOutputFormat], job.getConfiguration() ) + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 91ba364a346a5..ec689474aecb0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -126,6 +126,8 @@ object CassandraTest { } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index efd91bb054981..15f6678648b29 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -44,11 +44,11 @@ object GroupByTest { arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) } arr1 - }.cache + }.cache() // Enforce that everything has been calculated and in cache - pairs1.count + pairs1.count() - println(pairs1.groupByKey(numReducers).count) + println(pairs1.groupByKey(numReducers).count()) sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 4c655b84fde2e..74620ad007d83 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -79,5 +79,7 @@ object LogQuery { .reduceByKey((a, b) => a.merge(b)) .collect().foreach{ case (user, query) => println("%s\t%s".format(user, query))} + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index 235c3bf820244..e4db3ec51313d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -21,7 +21,6 @@ import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ import scala.xml.{XML,NodeSeq} @@ -78,9 +77,9 @@ object WikipediaPageRank { (id, new PRVertex(1.0 / numVertices, outEdges)) }) if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() } else { - vertices = vertices.cache + vertices = vertices.cache() } println("Done parsing input file.") @@ -100,7 +99,9 @@ object WikipediaPageRank { (result .filter { case (id, vertex) => vertex.value >= threshold } .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect.mkString) + .collect().mkString) println(top) + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index d56d64c564200..2e98b2dc30b80 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -51,7 +51,7 @@ object RDDRelation { val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") println("Result of RDD.map:") - rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect.foreach(println) + rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) @@ -68,5 +68,7 @@ object RDDRelation { // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") sql("SELECT * FROM parquetFile").collect().foreach(println) + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 3423fac0ad303..e26f213e8afa8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -39,7 +39,7 @@ object HiveFromSpark { // Queries are expressed in HiveQL println("Result of 'SELECT *': ") - sql("SELECT * FROM src").collect.foreach(println) + sql("SELECT * FROM src").collect().foreach(println) // Aggregation queries are also supported. val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0) @@ -61,5 +61,7 @@ object HiveFromSpark { // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + + sc.stop() } } From 6390aae4eacbabfb1c53fb828b824f6a6518beff Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 11:30:29 -0700 Subject: [PATCH 148/315] [SPARK-3755][Core] Do not bind port 1 - 1024 to server in spark Non-root user use port 1- 1024 to start jetty server will get the exception " java.net.SocketException: Permission denied", so not use these ports Author: scwf Closes #2610 from scwf/1-1024 and squashes the following commits: cb8cc76 [scwf] do not use port 1 - 1024 --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index dbe0cfa2b8ff9..1d80012dc142d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1437,7 +1437,7 @@ private[spark] object Utils extends Logging { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port - val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 + val tryPort = if (startPort == 0) startPort else (startPort + offset) % (65536 - 1024) + 1024 try { val (service, port) = startService(tryPort) logInfo(s"Successfully started service$serviceString on port $port.") From 2fedb5dddcc10d3186f49fc4996a7bb5b68bbc85 Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 11:51:30 -0700 Subject: [PATCH 149/315] [SPARK-3756] [Core]check exception is caused by an address-port collision properly Jetty server use MultiException to handle exceptions when start server refer https://github.com/eclipse/jetty.project/blob/jetty-8.1.14.v20131031/jetty-server/src/main/java/org/eclipse/jetty/server/Server.java So in ```isBindCollision``` add the logical to cover MultiException Author: scwf Closes #2611 from scwf/fix-isBindCollision and squashes the following commits: 984cb12 [scwf] optimize the fix 3a6c849 [scwf] fix bug in isBindCollision --- core/src/main/scala/org/apache/spark/util/Utils.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1d80012dc142d..e5b83c069d961 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -23,6 +23,8 @@ import java.nio.ByteBuffer import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} +import org.eclipse.jetty.util.MultiException + import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer @@ -1470,6 +1472,7 @@ private[spark] object Utils extends Logging { return true } isBindCollision(e.getCause) + case e: MultiException => e.getThrowables.exists(isBindCollision) case e: Exception => isBindCollision(e.getCause) case _ => false } From 8cc70e7e15fd800f31b94e9102069506360289db Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 1 Oct 2014 12:40:37 -0700 Subject: [PATCH 150/315] [SQL] Kill dangerous trailing space in query string MD5 of query strings in `createQueryTest` calls are used to generate golden files, leaving trailing spaces there can be really dangerous. Got bitten by this while working on #2616: my "smart" IDE automatically removed a trailing space and makes Jenkins fail. (Really should add "no trailing space" to our coding style guidelines!) Author: Cheng Lian Closes #2619 from liancheng/kill-trailing-space and squashes the following commits: 034f119 [Cheng Lian] Kill dangerous trailing space in query string --- ...tamp to Timestamp in UDF-0-db6d4503454e4dbb9edcbab9a8718d7f} | 0 .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename sql/hive/src/test/resources/golden/{Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa => Cast Timestamp to Timestamp in UDF-0-db6d4503454e4dbb9edcbab9a8718d7f} (100%) diff --git a/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa b/sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-db6d4503454e4dbb9edcbab9a8718d7f similarity index 100% rename from sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-66952a3949d7544716fd1a675498b1fa rename to sql/hive/src/test/resources/golden/Cast Timestamp to Timestamp in UDF-0-db6d4503454e4dbb9edcbab9a8718d7f diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2da8a6fac3d99..f5868bff22f13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -164,7 +164,7 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("Cast Timestamp to Timestamp in UDF", """ - | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) + | SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp)) | FROM src LIMIT 1 """.stripMargin) From b81ee0b46d63c2122b88941696654100fd736942 Mon Sep 17 00:00:00 2001 From: Gaspar Munoz Date: Wed, 1 Oct 2014 13:47:22 -0700 Subject: [PATCH 151/315] Typo error in KafkaWordCount example topicpMap to topicMap Author: Gaspar Munoz Closes #2614 from gasparms/patch-1 and squashes the following commits: 00aab2c [Gaspar Munoz] Typo error in KafkaWordCount example --- .../org/apache/spark/examples/streaming/KafkaWordCount.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index 566ba6f911e02..c9e1511278ede 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -53,8 +53,8 @@ object KafkaWordCount { val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicpMap).map(_._2) + val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) .reduceByKeyAndWindow(_ + _, _ - _, Minutes(10), Seconds(2), 2) From 17333c7a3c26ca6d28e8f3ca097da37d6b655217 Mon Sep 17 00:00:00 2001 From: jyotiska Date: Wed, 1 Oct 2014 13:52:50 -0700 Subject: [PATCH 152/315] Python SQL Example Code SQL example code for Python, as shown on [SQL Programming Guide](https://spark.apache.org/docs/1.0.2/sql-programming-guide.html) Author: jyotiska Closes #2521 from jyotiska/sql_example and squashes the following commits: 1471dcb [jyotiska] added imports for sql b25e436 [jyotiska] pep 8 compliance 43fd10a [jyotiska] lines broken to maintain 80 char limit b4fdf4e [jyotiska] removed blank lines 83d5ab7 [jyotiska] added inferschema and applyschema to the demo 306667e [jyotiska] replaced blank line with end line c90502a [jyotiska] fixed new line 4939a70 [jyotiska] added new line at end for python style 0b46148 [jyotiska] fixed appname for python sql example 8f67b5b [jyotiska] added python sql example --- examples/src/main/python/sql.py | 73 +++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 examples/src/main/python/sql.py diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py new file mode 100644 index 0000000000000..eefa022f1927c --- /dev/null +++ b/examples/src/main/python/sql.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.sql import Row, StructField, StructType, StringType, IntegerType + + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSQL") + sqlContext = SQLContext(sc) + + # RDD is created from a list of rows + some_rdd = sc.parallelize([Row(name="John", age=19), + Row(name="Smith", age=23), + Row(name="Sarah", age=18)]) + # Infer schema from the first row, create a SchemaRDD and print the schema + some_schemardd = sqlContext.inferSchema(some_rdd) + some_schemardd.printSchema() + + # Another RDD is created from a list of tuples + another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) + # Schema with two fields - person_name and person_age + schema = StructType([StructField("person_name", StringType(), False), + StructField("person_age", IntegerType(), False)]) + # Create a SchemaRDD by applying the schema to the RDD and print the schema + another_schemardd = sqlContext.applySchema(another_rdd, schema) + another_schemardd.printSchema() + # root + # |-- age: integer (nullable = true) + # |-- name: string (nullable = true) + + # A JSON dataset is pointed to by path. + # The path can be either a single text file or a directory storing text files. + path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json" + # Create a SchemaRDD from the file(s) pointed to by path + people = sqlContext.jsonFile(path) + # root + # |-- person_name: string (nullable = false) + # |-- person_age: integer (nullable = false) + + # The inferred schema can be visualized using the printSchema() method. + people.printSchema() + # root + # |-- age: IntegerType + # |-- name: StringType + + # Register this SchemaRDD as a table. + people.registerAsTable("people") + + # SQL statements can be run by using the sql methods provided by sqlContext + teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + + for each in teenagers.collect(): + print each[0] + + sc.stop() From fcad3fae6135bc2c9bdaf9e4c3cfe24838e63eae Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 1 Oct 2014 14:37:27 -0700 Subject: [PATCH 153/315] [SPARK-3746][SQL] Lock hive client when creating tables Author: Michael Armbrust Closes #2598 from marmbrus/hiveClientLock and squashes the following commits: ca89fe8 [Michael Armbrust] Lock hive client when creating tables --- .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9a0b9b46ac4ee..06b1446ccbd39 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -96,10 +96,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with serDeInfo.setParameters(Map[String, String]()) sd.setSerdeInfo(serDeInfo) - try client.createTable(table) catch { - case e: org.apache.hadoop.hive.ql.metadata.HiveException - if e.getCause.isInstanceOf[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] && - allowExisting => // Do nothing. + synchronized { + try client.createTable(table) catch { + case e: org.apache.hadoop.hive.ql.metadata.HiveException + if e.getCause.isInstanceOf[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] && + allowExisting => // Do nothing. + } } } From d61f2c15bb22253bfdda77462b1bd383987d2f5a Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 1 Oct 2014 15:15:09 -0700 Subject: [PATCH 154/315] [SPARK-3658][SQL] Start thrift server as a daemon https://issues.apache.org/jira/browse/SPARK-3658 And keep the `CLASS_NOT_FOUND_EXIT_STATUS` and exit message in `SparkSubmit.scala`. Author: WangTaoTheTonic Author: WangTao Closes #2509 from WangTaoTheTonic/thriftserver and squashes the following commits: 5dcaab2 [WangTaoTheTonic] issue about coupling 8ad9f95 [WangTaoTheTonic] generalization 598e21e [WangTao] take thrift server as a daemon --- bin/spark-sql | 12 +-------- .../org/apache/spark/deploy/SparkSubmit.scala | 4 +++ sbin/spark-daemon.sh | 16 ++++++++---- sbin/start-thriftserver.sh | 16 ++---------- sbin/stop-thriftserver.sh | 25 +++++++++++++++++++ 5 files changed, 43 insertions(+), 30 deletions(-) create mode 100755 sbin/stop-thriftserver.sh diff --git a/bin/spark-sql b/bin/spark-sql index 9d66140b6aa17..63d00437d508d 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -24,7 +24,6 @@ set -o posix CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" -CLASS_NOT_FOUND_EXIT_STATUS=101 # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -53,13 +52,4 @@ source "$FWDIR"/bin/utils.sh SUBMIT_USAGE_FUNCTION=usage gatherSparkSubmitOpts "$@" -"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" -exit_status=$? - -if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then - echo - echo "Failed to load Spark SQL CLI main class $CLASS." - echo "You need to build Spark with -Phive." -fi - -exit $exit_status +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 580a439c9a892..f97bf67fa5a3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -320,6 +320,10 @@ object SparkSubmit { } catch { case e: ClassNotFoundException => e.printStackTrace(printStream) + if (childMainClass.contains("thriftserver")) { + println(s"Failed to load main class $childMainClass.") + println("You need to build Spark with -Phive.") + } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index bd476b400e1c3..cba475e2dd8c8 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -62,7 +62,7 @@ then shift fi -startStop=$1 +option=$1 shift command=$1 shift @@ -122,9 +122,9 @@ if [ "$SPARK_NICENESS" = "" ]; then fi -case $startStop in +case $option in - (start) + (start|spark-submit) mkdir -p "$SPARK_PID_DIR" @@ -142,8 +142,14 @@ case $startStop in spark_rotate_log "$log" echo starting $command, logging to $log - cd "$SPARK_PREFIX" - nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & + if [ $option == spark-submit ]; then + source "$SPARK_HOME"/bin/utils.sh + gatherSparkSubmitOpts "$@" + nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-submit --class $command \ + "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" >> "$log" 2>&1 < /dev/null & + else + nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null & + fi newpid=$! echo $newpid > $pid sleep 2 diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index ba953e763faab..50e8e06418b07 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -27,7 +27,6 @@ set -o posix FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" -CLASS_NOT_FOUND_EXIT_STATUS=101 function usage { echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" @@ -49,17 +48,6 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then exit 0 fi -source "$FWDIR"/bin/utils.sh -SUBMIT_USAGE_FUNCTION=usage -gatherSparkSubmitOpts "$@" +export SUBMIT_USAGE_FUNCTION=usage -"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" -exit_status=$? - -if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then - echo - echo "Failed to load Hive Thrift server main class $CLASS." - echo "You need to build Spark with -Phive." -fi - -exit $exit_status +exec "$FWDIR"/sbin/spark-daemon.sh spark-submit $CLASS 1 "$@" diff --git a/sbin/stop-thriftserver.sh b/sbin/stop-thriftserver.sh new file mode 100755 index 0000000000000..4031a00d4a689 --- /dev/null +++ b/sbin/stop-thriftserver.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the thrift server on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1 From 3508ce8a5a05d6cb122ad59ba33c3cc18e17e2a4 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Wed, 1 Oct 2014 15:44:41 -0700 Subject: [PATCH 155/315] [SPARK-3708][SQL] Backticks aren't handled correctly is aliases The below query gives error sql("SELECT k FROM (SELECT \`key\` AS \`k\` FROM src) a") It gives error because the aliases are not cleaned so it could not be resolved in further processing. Author: ravipesala Closes #2594 from ravipesala/SPARK-3708 and squashes the following commits: d55db54 [ravipesala] Fixed SPARK-3708 (Backticks aren't handled correctly is aliases) --- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0aa6292c0184e..4f3f808c93dc8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -855,7 +855,7 @@ private[hive] object HiveQl { case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), alias)()) + Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 679efe082f2a0..3647bb1c4ce7d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -63,4 +63,10 @@ class SQLQuerySuite extends QueryTest { sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } + + test("SPARK-3708 Backticks aren't handled correctly is aliases") { + checkAnswer( + sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), + sql("SELECT `key` FROM src").collect().toSeq) + } } From f315fb7efc95afb2cc1208159b48359ba56a010d Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 15:55:09 -0700 Subject: [PATCH 156/315] [SPARK-3705][SQL] Add case for VoidObjectInspector to cover NullType add case for VoidObjectInspector in ```inspectorToDataType``` Author: scwf Closes #2552 from scwf/inspectorToDataType and squashes the following commits: 453d892 [scwf] add case for VoidObjectInspector --- .../main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index fa889ec104c6e..d633c42c6bd67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -213,6 +213,8 @@ private[hive] trait HiveInspectors { case _: JavaHiveDecimalObjectInspector => DecimalType case _: WritableTimestampObjectInspector => TimestampType case _: JavaTimestampObjectInspector => TimestampType + case _: WritableVoidObjectInspector => NullType + case _: JavaVoidObjectInspector => NullType } implicit class typeInfoConversions(dt: DataType) { From f84b228c4002073ee4ff53be50462a63f48bd508 Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Wed, 1 Oct 2014 15:57:06 -0700 Subject: [PATCH 157/315] [SPARK-3593][SQL] Add support for sorting BinaryType BinaryType is derived from NativeType and added Ordering support. Author: Venkata Ramana G Author: Venkata Ramana Gollamudi Closes #2617 from gvramana/binarytype_sort and squashes the following commits: 1cf26f3 [Venkata Ramana Gollamudi] Supported Sorting of BinaryType --- .../apache/spark/sql/catalyst/types/dataTypes.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 10 ++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index c7d73d3990c3a..ac043d4dd8eb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -157,8 +157,18 @@ case object StringType extends NativeType with PrimitiveType { def simpleString: String = "string" } -case object BinaryType extends DataType with PrimitiveType { +case object BinaryType extends NativeType with PrimitiveType { private[sql] type JvmType = Array[Byte] + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = new Ordering[JvmType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + return x.length - y.length + } + } def simpleString: String = "binary" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 08376eb5e5c4e..fdf3a229a796e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -190,6 +190,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + checkAnswer( + sql("SELECT b FROM binaryData ORDER BY a ASC"), + (1 to 5).map(Row(_)).toSeq) + + checkAnswer( + sql("SELECT b FROM binaryData ORDER BY a DESC"), + (1 to 5).map(Row(_)).toSeq.reverse) + checkAnswer( sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), arrayData.collect().sortBy(_.data(0)).toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index eb33a61c6e811..10b7979df7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -54,6 +54,16 @@ object TestData { TestData2(3, 2) :: Nil) testData2.registerTempTable("testData2") + case class BinaryData(a: Array[Byte], b: Int) + val binaryData: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + BinaryData("12".getBytes(), 1) :: + BinaryData("22".getBytes(), 5) :: + BinaryData("122".getBytes(), 3) :: + BinaryData("121".getBytes(), 2) :: + BinaryData("123".getBytes(), 4) :: Nil) + binaryData.registerTempTable("binaryData") + // TODO: There is no way to express null primitives as case classes currently... val testData3 = logical.LocalRelation('a.int, 'b.int).loadData( From a31f4ff22f98c01f0d9b7d1240080ff2689c1270 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 1 Oct 2014 16:00:29 -0700 Subject: [PATCH 158/315] [SQL] Made Command.sideEffectResult protected Considering `Command.executeCollect()` simply delegates to `Command.sideEffectResult`, we no longer need to leave the latter `protected[sql]`. Author: Cheng Lian Closes #2431 from liancheng/narrow-scope and squashes the following commits: 1bfc16a [Cheng Lian] Made Command.sideEffectResult protected --- .../apache/spark/sql/execution/commands.scala | 10 +++++----- .../org/apache/spark/sql/hive/HiveContext.scala | 2 +- .../sql/hive/execution/CreateTableAsSelect.scala | 16 ++++++++-------- .../execution/DescribeHiveTableCommand.scala | 2 +- .../spark/sql/hive/execution/NativeCommand.scala | 2 +- .../spark/sql/hive/execution/commands.scala | 6 +++--- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index c2f48a902a3e9..f88099ec0761e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -37,7 +37,7 @@ trait Command { * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] + protected lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] override def executeCollect(): Array[Row] = sideEffectResult.toArray @@ -53,7 +53,7 @@ case class SetCommand( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match { + override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { @@ -121,7 +121,7 @@ case class ExplainCommand( extends LeafNode with Command { // Run through the optimizer to generate the physical plan. - override protected[sql] lazy val sideEffectResult: Seq[Row] = try { + override protected lazy val sideEffectResult: Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = context.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString @@ -141,7 +141,7 @@ case class ExplainCommand( case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult = { + override protected lazy val sideEffectResult = { if (doCache) { context.cacheTable(tableName) } else { @@ -161,7 +161,7 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { Row("# Registered as a temporary table", null, null) +: child.output.map(field => Row(field.name, field.dataType.toString, null)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3e1a7b71528e0..20ebe4996c207 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -404,7 +404,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // be similar with Hive. describeHiveTableCommand.hiveString case command: PhysicalCommand => - command.sideEffectResult.map(_.head.toString) + command.executeCollect().map(_.head.toString) case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 1017fe6d5396d..3625708d03175 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -30,23 +30,23 @@ import org.apache.spark.sql.hive.MetastoreRelation * Create table and insert the query result into it. * @param database the database name of the new relation * @param tableName the table name of the new relation - * @param insertIntoRelation function of creating the `InsertIntoHiveTable` + * @param insertIntoRelation function of creating the `InsertIntoHiveTable` * by specifying the `MetaStoreRelation`, the data will be inserted into that table. * TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc. */ @Experimental case class CreateTableAsSelect( - database: String, - tableName: String, - query: SparkPlan, - insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) - extends LeafNode with Command { + database: String, + tableName: String, + query: SparkPlan, + insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) + extends LeafNode with Command { def output = Seq.empty // A lazy computing of the metastoreRelation private[this] lazy val metastoreRelation: MetastoreRelation = { - // Create the table + // Create the table val sc = sqlContext.asInstanceOf[HiveContext] sc.catalog.createTable(database, tableName, query.output, false) // Get the Metastore Relation @@ -55,7 +55,7 @@ case class CreateTableAsSelect( } } - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { insertIntoRelation(metastoreRelation).execute Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 317801001c7a4..106cede9788ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( .mkString("\t") } - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala index 8f10e1ba7f426..6930c2babd117 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -32,7 +32,7 @@ case class NativeCommand( @transient context: HiveContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) + override protected lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) override def otherCopyArgs = context :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index d61c5e274a596..0fc674af31885 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -37,7 +37,7 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command { def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { hiveContext.analyze(tableName) Seq.empty[Row] } @@ -53,7 +53,7 @@ case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { val ifExistsClause = if (ifExists) "IF EXISTS " else "" hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(None, tableName) @@ -70,7 +70,7 @@ case class AddJar(path: String) extends LeafNode with Command { override def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) Seq.empty[Row] From 4e79970d32f9b917590dba8319bdc677e3bdd63a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 1 Oct 2014 16:03:00 -0700 Subject: [PATCH 159/315] Revert "[SPARK-3755][Core] Do not bind port 1 - 1024 to server in spark" This reverts commit 6390aae4eacbabfb1c53fb828b824f6a6518beff. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e5b83c069d961..b3025c6ec3364 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1439,7 +1439,7 @@ private[spark] object Utils extends Logging { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port - val tryPort = if (startPort == 0) startPort else (startPort + offset) % (65536 - 1024) + 1024 + val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 try { val (service, port) = startService(tryPort) logInfo(s"Successfully started service$serviceString on port $port.") From 45e058ca4babbe3cef6524b6a0f48b466a5139bf Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 1 Oct 2014 16:30:28 -0700 Subject: [PATCH 160/315] [SPARK-3729][SQL] Do all hive session state initialization in lazy val This change avoids a NPE during context initialization when settings are present. Author: Michael Armbrust Closes #2583 from marmbrus/configNPE and squashes the following commits: da2ec57 [Michael Armbrust] Do all hive session state initilialization in lazy val --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 ++++---- .../main/scala/org/apache/spark/sql/hive/TestHive.scala | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 20ebe4996c207..fdb56901f9ddb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,12 +231,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + SessionState.start(ss) + ss.err = new PrintStream(outputBuffer, true, "UTF-8") + ss.out = new PrintStream(outputBuffer, true, "UTF-8") + ss } - sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") - sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def setConf(key: String, value: String): Unit = { super.setConf(key, value) runSqlHive(s"SET $key=$value") @@ -273,7 +274,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { results } - SessionState.start(sessionState) /** * Execute the command using Hive and return the results as a sequence. Each element diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 70fb15259e7d7..4a999b98ad92b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -40,8 +40,10 @@ import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ +// SPARK-3729: Test key required to check for initialization errors with config. object TestHive - extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) + extends TestHiveContext( + new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) /** * A locally running test instance of Spark's Hive execution engine. From 1b9f0d67f28011cdff316042b344c9891f986aaa Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 16:38:10 -0700 Subject: [PATCH 161/315] [SPARK-3704][SQL] Fix ColumnValue type for Short values in thrift server case ```ShortType```, we should add short value to hive row. Int value may lead to some problems. Author: scwf Closes #2551 from scwf/fix-addColumnValue and squashes the following commits: 08bcc59 [scwf] ColumnValue.shortValue for short type --- .../hive/thriftserver/server/SparkSQLOperationManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index bd3f68d92d8c7..910174a153768 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -113,7 +113,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) case ByteType => to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) case ShortType => - to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) case TimestampType => to.addColumnValue( ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) @@ -145,7 +145,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) case ByteType => to.addColumnValue(ColumnValue.byteValue(null)) case ShortType => - to.addColumnValue(ColumnValue.intValue(null)) + to.addColumnValue(ColumnValue.shortValue(null)) case TimestampType => to.addColumnValue(ColumnValue.timestampValue(null)) case BinaryType | _: ArrayType | _: StructType | _: MapType => From 93861a5e876fa57f509cce82768656ddf8d4ef00 Mon Sep 17 00:00:00 2001 From: aniketbhatnagar Date: Wed, 1 Oct 2014 18:31:18 -0700 Subject: [PATCH 162/315] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile This patch forces use of commons http client 4.2 in Kinesis-asl profile so that the AWS SDK does not run into dependency conflicts Author: aniketbhatnagar Closes #2535 from aniketbhatnagar/Kinesis-HttpClient-Dep-Fix and squashes the following commits: aa2079f [aniketbhatnagar] Merge branch 'Kinesis-HttpClient-Dep-Fix' of https://github.com/aniketbhatnagar/spark into Kinesis-HttpClient-Dep-Fix 73f55f6 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile 70cc75b [aniketbhatnagar] deleted merge files 725dbc9 [aniketbhatnagar] Merge remote-tracking branch 'origin/Kinesis-HttpClient-Dep-Fix' into Kinesis-HttpClient-Dep-Fix 4ed61d8 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile 9cd6103 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile --- assembly/pom.xml | 10 ++++++++++ examples/pom.xml | 5 +++++ pom.xml | 1 + 3 files changed, 16 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index 5ec9da22ae83f..31a01e4d8e1de 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -349,5 +349,15 @@ + + kinesis-asl + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + diff --git a/examples/pom.xml b/examples/pom.xml index 2b561857f9f33..eb49a0e5af22d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -43,6 +43,11 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + diff --git a/pom.xml b/pom.xml index 70cb9729ff6d3..7756c89b00cad 100644 --- a/pom.xml +++ b/pom.xml @@ -138,6 +138,7 @@ 0.7.1 1.8.3 1.1.0 + 4.2.6 64m 512m From 29c3513203218af33bea2f6d99d622cf263d55dd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Oct 2014 19:24:22 -0700 Subject: [PATCH 163/315] [SPARK-3446] Expose underlying job ids in FutureAction. FutureAction is the only type exposed through the async APIs, so for job IDs to be useful they need to be exposed there. The complication is that some async jobs run more than one job (e.g. takeAsync), so the exposed ID has to actually be a list of IDs that can actually change over time. So the interface doesn't look very nice, but... Change is actually small, I just added a basic test to make sure it works. Author: Marcelo Vanzin Closes #2337 from vanzin/SPARK-3446 and squashes the following commits: e166a68 [Marcelo Vanzin] Fix comment. 1fed2bc [Marcelo Vanzin] [SPARK-3446] Expose underlying job ids in FutureAction. --- .../scala/org/apache/spark/FutureAction.scala | 19 ++++++- .../org/apache/spark/FutureActionSuite.scala | 49 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/FutureActionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 75ea535f2f57b..e8f761eaa5799 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -83,6 +83,15 @@ trait FutureAction[T] extends Future[T] { */ @throws(classOf[Exception]) def get(): T = Await.result(this, Duration.Inf) + + /** + * Returns the job IDs run by the underlying async operation. + * + * This returns the current snapshot of the job list. Certain operations may run multiple + * jobs, so multiple calls to this method may return different lists. + */ + def jobIds: Seq[Int] + } @@ -150,8 +159,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - /** Get the corresponding job id for this action. */ - def jobId = jobWaiter.jobId + def jobIds = Seq(jobWaiter.jobId) } @@ -171,6 +179,8 @@ class ComplexFutureAction[T] extends FutureAction[T] { // is cancelled before the action was even run (and thus we have no thread to interrupt). @volatile private var _cancelled: Boolean = false + @volatile private var jobs: Seq[Int] = Nil + // A promise used to signal the future. private val p = promise[T]() @@ -219,6 +229,8 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } + this.jobs = jobs ++ job.jobIds + // Wait for the job to complete. If the action is cancelled (with an interrupt), // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. @@ -255,4 +267,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def isCompleted: Boolean = p.isCompleted override def value: Option[Try[T]] = p.future.value + + def jobIds = jobs + } diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala new file mode 100644 index 0000000000000..db9c25fc457a4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} + +import org.apache.spark.SparkContext._ + +class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext { + + before { + sc = new SparkContext("local", "FutureActionSuite") + } + + test("simple async action") { + val rdd = sc.parallelize(1 to 10, 2) + val job = rdd.countAsync() + val res = Await.result(job, Duration.Inf) + res should be (10) + job.jobIds.size should be (1) + } + + test("complex async action") { + val rdd = sc.parallelize(1 to 15, 3) + val job = rdd.takeAsync(10) + val res = Await.result(job, Duration.Inf) + res should be (1 to 10) + job.jobIds.size should be (2) + } + +} From f341e1c8a284b55cceb367a432c1fa5203692155 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 1 Oct 2014 23:08:51 -0700 Subject: [PATCH 164/315] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1375 (close requested by 'pwendell') Closes #476 (close requested by 'mengxr') Closes #2502 (close requested by 'pwendell') Closes #2391 (close requested by 'andrewor14') From bbdf1de84ffdd3bd172f17975d2f1422a9bcf2c6 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Wed, 1 Oct 2014 23:53:21 -0700 Subject: [PATCH 165/315] [SPARK-3371][SQL] Renaming a function expression with group by gives error The following code gives error. ``` sqlContext.registerFunction("len", (s: String) => s.length) sqlContext.sql("select len(foo) as a, count(1) from t1 group by len(foo)").collect() ``` Because SQl parser creates the aliases to the functions in grouping expressions with generated alias names. So if user gives the alias names to the functions inside projection then it does not match the generated alias name of grouping expression. This kind of queries are working in Hive. So the fix I have given that if user provides alias to the function in projection then don't generate alias in grouping expression,use the same alias. Author: ravipesala Closes #2511 from ravipesala/SPARK-3371 and squashes the following commits: 9fb973f [ravipesala] Removed aliases to grouping expressions. f8ace79 [ravipesala] Fixed the testcase issue bad2fd0 [ravipesala] SPARK-3371 : Fixed Renaming a function expression with group by gives error --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 862f78702c4e6..26336332c05a2 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -166,7 +166,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val withFilter = f.map(f => Filter(f, base)).getOrElse(base) val withProjection = g.map {g => - Aggregate(assignAliases(g), assignAliases(p), withFilter) + Aggregate(g, assignAliases(p), withFilter) }.getOrElse(Project(assignAliases(p), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fdf3a229a796e..6fb6cb8db0c8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,4 +680,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } + + test("SPARK-3371 Renaming a function expression with group by gives error") { + registerFunction("len", (s: String) => s.length) + checkAnswer( + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} } From 6e27cb630de69fa5acb510b4e2f6b980742b1957 Mon Sep 17 00:00:00 2001 From: Colin Patrick Mccabe Date: Thu, 2 Oct 2014 00:29:31 -0700 Subject: [PATCH 166/315] SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks This change reorders the replicas returned by HadoopRDD#getPreferredLocations so that replicas cached by HDFS are at the start of the list. This requires Hadoop 2.5 or higher; previous versions of Hadoop do not expose the information needed to determine whether a replica is cached. Author: Colin Patrick Mccabe Closes #1486 from cmccabe/SPARK-1767 and squashes the following commits: 338d4f8 [Colin Patrick Mccabe] SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks --- .../org/apache/spark/rdd/HadoopRDD.scala | 60 +++++++++++++++++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 18 +++++- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/scheduler/TaskLocation.scala | 48 +++++++++++++-- .../spark/scheduler/TaskSetManager.scala | 25 +++++++- .../spark/scheduler/TaskSetManagerSuite.scala | 22 +++++++ project/MimaExcludes.scala | 2 + 8 files changed, 162 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 21d0cc7b5cbaa..6b63eb23e9ee1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -23,6 +23,7 @@ import java.io.EOFException import scala.collection.immutable.Map import scala.reflect.ClassTag +import scala.collection.mutable.ListBuffer import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit @@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} /** @@ -249,9 +251,21 @@ class HadoopRDD[K, V]( } override def getPreferredLocations(split: Partition): Seq[String] = { - // TODO: Filtering out "localhost" in case of file:// URLs - val hadoopSplit = split.asInstanceOf[HadoopPartition] - hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") + val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value + val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val lsplit = c.inputSplitWithLocationInfo.cast(hsplit) + val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e: Exception => + logDebug("Failed to use InputSplitWithLocations.", e) + None + } + case None => None + } + locs.getOrElse(hsplit.getLocations.filter(_ != "localhost")) } override def checkpoint() { @@ -261,7 +275,7 @@ class HadoopRDD[K, V]( def getConf: Configuration = getJobConf() } -private[spark] object HadoopRDD { +private[spark] object HadoopRDD extends Logging { /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ val CONFIGURATION_INSTANTIATION_LOCK = new Object() @@ -309,4 +323,42 @@ private[spark] object HadoopRDD { f(inputSplit, firstParent[T].iterator(split, context)) } } + + private[spark] class SplitInfoReflections { + val inputSplitWithLocationInfo = + Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") + val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") + val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val isInMemory = splitLocationInfo.getMethod("isInMemory") + val getLocation = splitLocationInfo.getMethod("getLocation") + } + + private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try { + Some(new SplitInfoReflections) + } catch { + case e: Exception => + logDebug("SplitLocationInfo and other new Hadoop classes are " + + "unavailable. Using the older Hadoop location info code.", e) + None + } + + private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { + val out = ListBuffer[String]() + infos.foreach { loc => { + val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. + getLocation.invoke(loc).asInstanceOf[String] + if (locationStr != "localhost") { + if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory. + invoke(loc).asInstanceOf[Boolean]) { + logDebug("Partition " + locationStr + " is cached by Hadoop.") + out += new HDFSCacheTaskLocation(locationStr).toString + } else { + out += new HostTaskLocation(locationStr).toString + } + } + }} + out.seq + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 4c84b3f62354d..0cccdefc5ee09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -173,9 +173,21 @@ class NewHadoopRDD[K, V]( new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) } - override def getPreferredLocations(split: Partition): Seq[String] = { - val theSplit = split.asInstanceOf[NewHadoopPartition] - theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") + override def getPreferredLocations(hsplit: Partition): Seq[String] = { + val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) } def getConf: Configuration = confBroadcast.value.value diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ab9e97c8fe409..2aba40d152e3e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag]( } /** - * Get the preferred locations of a partition (as hostnames), taking into account whether the + * Get the preferred locations of a partition, taking into account whether the * RDD is checkpointed. */ final def preferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5a96f52a10cd4..8135cdbb4c31f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1303,7 +1303,7 @@ class DAGScheduler( // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList if (!rddPrefs.isEmpty) { - return rddPrefs.map(host => TaskLocation(host)) + return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep // that has any placement preferences. Ideally we would choose based on transfer sizes, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 67c9a6760b1b3..10c685f29d3ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -22,13 +22,51 @@ package org.apache.spark.scheduler * In the latter case, we will prefer to launch the task on that executorID, but our next level * of preference will be executors on the same host if this is not possible. */ -private[spark] -class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { - override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" +private[spark] sealed trait TaskLocation { + def host: String +} + +/** + * A location that includes both a host and an executor id on that host. + */ +private [spark] case class ExecutorCacheTaskLocation(override val host: String, + val executorId: String) extends TaskLocation { +} + +/** + * A location on a host. + */ +private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { + override def toString = host +} + +/** + * A location on a host that is cached by HDFS. + */ +private [spark] case class HDFSCacheTaskLocation(override val host: String) + extends TaskLocation { + override def toString = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { - def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) + // We identify hosts on which the block is cached with this prefix. Because this prefix contains + // underscores, which are not legal characters in hostnames, there should be no potential for + // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. + val inMemoryLocationTag = "hdfs_cache_" + + def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) - def apply(host: String) = new TaskLocation(host, None) + /** + * Create a TaskLocation from a string returned by getPreferredLocations. + * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the + * location is cached. + */ + def apply(str: String) = { + val hstr = str.stripPrefix(inMemoryLocationTag) + if (hstr.equals(str)) { + new HostTaskLocation(str) + } else { + new HostTaskLocation(hstr) + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d9d53faf843ff..a6c23fc85a1b0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -181,8 +181,24 @@ private[spark] class TaskSetManager( } for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + loc match { + case e: ExecutorCacheTaskLocation => + addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer)) + case e: HDFSCacheTaskLocation => { + val exe = sched.getExecutorsAliveOnHost(loc.host) + exe match { + case Some(set) => { + for (e <- set) { + addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer)) + } + logInfo(s"Pending task $index has a cached location at ${e.host} " + + ", where there are executors " + set.mkString(",")) + } + case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + + ", but there are no executors alive there.") + } + } + case _ => Unit } addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { @@ -283,7 +299,10 @@ private[spark] class TaskSetManager( // on multiple nodes when we replicate cached blocks, as in Spark Streaming for (index <- speculatableTasks if canRunOnHost(index)) { val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) + val executors = prefs.flatMap(_ match { + case e: ExecutorCacheTaskLocation => Some(e.executorId) + case _ => None + }); if (executors.contains(execId)) { speculatableTasks -= index return Some((index, TaskLocality.PROCESS_LOCAL)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 93e8ddacf8865..c0b07649eb6dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execC", "host3", ANY) !== None) } + test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(3, + Seq(HostTaskLocation("host1")), + Seq(HostTaskLocation("host2")), + Seq(HDFSCacheTaskLocation("host3"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execA") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execB") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execC") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + } def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4076ebc6fc8d5..d499302124461 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,6 +41,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ Seq( + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.TaskLocation"), // Added normL1 and normL2 to trait MultivariateStatisticalSummary ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), From 5b4a5b1acdc439a58aa2a3561ac0e3fb09f529d6 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Thu, 2 Oct 2014 11:13:19 -0700 Subject: [PATCH 167/315] [SPARK-3706][PySpark] Cannot run IPython REPL with IPYTHON set to "1" and PYSPARK_PYTHON unset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Problem The section "Using the shell" in Spark Programming Guide (https://spark.apache.org/docs/latest/programming-guide.html#using-the-shell) says that we can run pyspark REPL through IPython. But a folloing command does not run IPython but a default Python executable. ``` $ IPYTHON=1 ./bin/pyspark Python 2.7.8 (default, Jul 2 2014, 10:14:46) ... ``` the spark/bin/pyspark script on the commit b235e013638685758885842dc3268e9800af3678 decides which executable and options it use folloing way. 1. if PYSPARK_PYTHON unset * → defaulting to "python" 2. if IPYTHON_OPTS set * → set IPYTHON "1" 3. some python scripts passed to ./bin/pyspak → run it with ./bin/spark-submit * out of this issues scope 4. if IPYTHON set as "1" * → execute $PYSPARK_PYTHON (default: ipython) with arguments $IPYTHON_OPTS * otherwise execute $PYSPARK_PYTHON Therefore, when PYSPARK_PYTHON is unset, python is executed though IPYTHON is "1". In other word, when PYSPARK_PYTHON is unset, IPYTHON_OPS and IPYTHON has no effect on decide which command to use. PYSPARK_PYTHON | IPYTHON_OPTS | IPYTHON | resulting command | expected command ---- | ---- | ----- | ----- | ----- (unset → defaults to python) | (unset) | (unset) | python | (same) (unset → defaults to python) | (unset) | 1 | python | ipython (unset → defaults to python) | an_option | (unset → set to 1) | python an_option | ipython an_option (unset → defaults to python) | an_option | 1 | python an_option | ipython an_option ipython | (unset) | (unset) | ipython | (same) ipython | (unset) | 1 | ipython | (same) ipython | an_option | (unset → set to 1) | ipython an_option | (same) ipython | an_option | 1 | ipython an_option | (same) ### Suggestion The pyspark script should determine firstly whether a user wants to run IPython or other executables. 1. if IPYTHON_OPTS set * set IPYTHON "1" 2. if IPYTHON has a value "1" * PYSPARK_PYTHON defaults to "ipython" if not set 3. PYSPARK_PYTHON defaults to "python" if not set See the pull request for more detailed modification. Author: cocoatomo Closes #2554 from cocoatomo/issues/cannot-run-ipython-without-options and squashes the following commits: d2a9b06 [cocoatomo] [SPARK-3706][PySpark] Use PYTHONUNBUFFERED environment variable instead of -u option 264114c [cocoatomo] [SPARK-3706][PySpark] Remove the sentence about deprecated environment variables 42e02d5 [cocoatomo] [SPARK-3706][PySpark] Replace environment variables used to customize execution of PySpark REPL 10d56fb [cocoatomo] [SPARK-3706][PySpark] Cannot run IPython REPL with IPYTHON set to "1" and PYSPARK_PYTHON unset --- bin/pyspark | 24 +++++++++---------- .../apache/spark/deploy/PythonRunner.scala | 3 ++- docs/programming-guide.md | 8 +++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 5142411e36974..6655725ef8e8e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -52,10 +52,20 @@ fi # Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then - PYSPARK_PYTHON="python" + if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then + # for backward compatibility + PYSPARK_PYTHON="ipython" + else + PYSPARK_PYTHON="python" + fi fi export PYSPARK_PYTHON +if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then + # for backward compatibility + PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" +fi + # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -64,11 +74,6 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" -# If IPython options are specified, assume user wants to run IPython -if [[ -n "$IPYTHON_OPTS" ]]; then - IPYTHON=1 -fi - # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. @@ -106,10 +111,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - # Only use ipython if no command line arguments were provided [SPARK-1134] - if [[ "$IPYTHON" = "1" ]]; then - exec ${PYSPARK_PYTHON:-ipython} $IPYTHON_OPTS - else - exec "$PYSPARK_PYTHON" - fi + exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS fi diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index b66c3ba4d5fb0..79b4d7ea41a33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -54,9 +54,10 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 1d61a3c555eaf..8e8cc1dd983f8 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `IPYTHON` variable to `1` when running `bin/pyspark`: +use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ IPYTHON=1 ./bin/pyspark +$ PYSPARK_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `IPYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ IPYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} From 82a6a083a485140858bcd93d73adec59bb5cca64 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Oct 2014 11:37:24 -0700 Subject: [PATCH 168/315] [SQL][Docs] Update the output of printSchema and fix a typo in SQL programming guide. We have changed the output format of `printSchema`. This PR will update our SQL programming guide to show the updated format. Also, it fixes a typo (the value type of `StructType` in Java API). Author: Yin Huai Closes #2630 from yhuai/sqlDoc and squashes the following commits: 267d63e [Yin Huai] Update the output of printSchema and fix a typo. --- docs/sql-programming-guide.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 818fd5ab80af8..368c3d0008b07 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -620,8 +620,8 @@ val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() // root -// |-- age: IntegerType -// |-- name: StringType +// |-- age: integer (nullable = true) +// |-- name: string (nullable = true) // Register this SchemaRDD as a table. people.registerTempTable("people") @@ -658,8 +658,8 @@ JavaSchemaRDD people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); // root -// |-- age: IntegerType -// |-- name: StringType +// |-- age: integer (nullable = true) +// |-- name: string (nullable = true) // Register this JavaSchemaRDD as a table. people.registerTempTable("people"); @@ -697,8 +697,8 @@ people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. people.printSchema() # root -# |-- age: IntegerType -# |-- name: StringType +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) # Register this SchemaRDD as a table. people.registerTempTable("people") @@ -1394,7 +1394,7 @@ please use factory methods provided in - + - + - + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 26dbd6237c6b8..a12f82d2fbe70 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf import org.apache.spark.util.{Utils, IntParam, MemoryParam} - +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -39,15 +39,17 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var appName: String = "Spark" var priority = 0 + parseArgs(args.toList) + loadEnvironmentArgs() + // Additional memory to allocate to containers // For now, use driver's memory overhead as our AM container's memory overhead - val amMemoryOverhead = sparkConf.getInt( - "spark.yarn.driver.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - val executorMemoryOverhead = sparkConf.getInt( - "spark.yarn.executor.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) + + val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) - parseArgs(args.toList) - loadEnvironmentArgs() validateArgs() /** Load any default arguments provided through environment variables and Spark properties. */ diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 1cf19c198509c..6ecac6eae6e03 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -64,14 +64,18 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory ($executorMem MB) " + - s"is above the max threshold ($maxMem MB) of this cluster!") + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory ($amMem MB) " + - s"is above the max threshold ($maxMem MB) of this cluster!") + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } + logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( + amMem, + amMemoryOverhead)) + // We could add checks to make sure the entire cluster has enough resources but that involves // getting all the node reports and computing ourselves. } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 299e38a5eb9c0..4f4f1d2aaaade 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ object AllocationType extends Enumeration { type AllocationType = Value @@ -78,10 +79,6 @@ private[yarn] abstract class YarnAllocator( // Containers to be released in next request to RM private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] - // Additional memory overhead - in mb. - protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - // Number of container requests that have been sent to, but not yet allocated by the // ApplicationMaster. private val numPendingAllocate = new AtomicInteger() @@ -97,6 +94,10 @@ private[yarn] abstract class YarnAllocator( protected val (preferredHostToCount, preferredRackToCount) = generateNodeToWeight(conf, preferredNodes) + // Additional memory overhead - in mb. + protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, @@ -114,12 +115,11 @@ private[yarn] abstract class YarnAllocator( // this is needed by alpha, do it here since we add numPending right after this val executorsPending = numPendingAllocate.get() - if (missing > 0) { + val totalExecutorMemory = executorMemory + memoryOverhead numPendingAllocate.addAndGet(missing) - logInfo("Will Allocate %d executor containers, each with %d memory".format( - missing, - (executorMemory + memoryOverhead))) + logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + + s"memory including $memoryOverhead MB overhead") } else { logDebug("Empty allocation request ...") } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 0b712c201904a..e1e0144f46fe9 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -84,8 +84,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { - // Additional memory overhead - in mb. - val DEFAULT_MEMORY_OVERHEAD = 384 + // Additional memory overhead + // 7% was arrived at experimentally. In the interest of minimizing memory waste while covering + // the common cases. Memory overhead tends to grow with container size. + + val MEMORY_OVERHEAD_FACTOR = 0.07 + val MEMORY_OVERHEAD_MIN = 384 val ANY_HOST = "*" From c6469a02f14e8c23e9b4e1336768f8bbfc15f5d8 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 2 Oct 2014 13:47:30 -0700 Subject: [PATCH 170/315] [SPARK-3766][Doc]Snappy is also the default compress codec for broadcast variables Author: scwf Closes #2632 from scwf/compress-doc and squashes the following commits: 7983a1a [scwf] snappy is the default compression codec for broadcast --- docs/configuration.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 791b6f2aa3261..316490f0f43fc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -413,10 +413,11 @@ Apart from these, the following properties are also available, and may be useful From 5db78e6b87d33ac2d48a997e69b46e9be3b63137 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 2 Oct 2014 13:49:47 -0700 Subject: [PATCH 171/315] [SPARK-3495] Block replication fails continuously when the replication target node is dead AND [SPARK-3496] Block replication by mistake chooses driver as target If a block manager (say, A) wants to replicate a block and the node chosen for replication (say, B) is dead, then the attempt to send the block to B fails. However, this continues to fail indefinitely. Even if the driver learns about the demise of the B, A continues to try replicating to B and failing miserably. The reason behind this bug is that A initially fetches a list of peers from the driver (when B was active), but never updates it after B is dead. This affects Spark Streaming as its receiver uses block replication. The solution in this patch adds the following. - Changed BlockManagerMaster to return all the peers of a block manager, rather than the requested number. It also filters out driver BlockManager. - Refactored BlockManager's replication code to handle peer caching correctly. + The peer for replication is randomly selected. This is different from past behavior where for a node A, a node B was deterministically chosen for the lifetime of the application. + If replication fails to one node, the peers are refetched. + The peer cached has a TTL of 1 second to enable discovery of new peers and using them for replication. - Refactored use of \ in BlockManager into a new method `BlockManagerId.isDriver` - Added replication unit tests (replication was not tested till now, duh!) This should not make a difference in performance of Spark workloads where replication is not used. @andrewor14 @JoshRosen Author: Tathagata Das Closes #2366 from tdas/replication-fix and squashes the following commits: 9690f57 [Tathagata Das] Moved replication tests to a new BlockManagerReplicationSuite. 0661773 [Tathagata Das] Minor changes based on PR comments. a55a65c [Tathagata Das] Added a unit test to test replication behavior. 012afa3 [Tathagata Das] Bug fix 89f91a0 [Tathagata Das] Minor change. 68e2c72 [Tathagata Das] Made replication peer selection logic more efficient. 08afaa9 [Tathagata Das] Made peer selection for replication deterministic to block id 3821ab9 [Tathagata Das] Fixes based on PR comments. 08e5646 [Tathagata Das] More minor changes. d402506 [Tathagata Das] Fixed imports. 4a20531 [Tathagata Das] Filtered driver block manager from peer list, and also consolidated the use of in BlockManager. 7598f91 [Tathagata Das] Minor changes. 03de02d [Tathagata Das] Change replication logic to correctly refetch peers from master on failure and on new worker addition. d081bf6 [Tathagata Das] Fixed bug in get peers and unit tests to test get-peers and replication under executor churn. 9f0ac9f [Tathagata Das] Modified replication tests to fail on replication bug. af0c1da [Tathagata Das] Added replication unit tests to BlockManagerSuite --- .../apache/spark/storage/BlockManager.scala | 122 ++++- .../apache/spark/storage/BlockManagerId.scala | 2 + .../spark/storage/BlockManagerMaster.scala | 9 +- .../storage/BlockManagerMasterActor.scala | 29 +- .../spark/storage/BlockManagerMessages.scala | 2 +- .../spark/broadcast/BroadcastSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 418 ++++++++++++++++++ .../spark/storage/BlockManagerSuite.scala | 9 +- 8 files changed, 544 insertions(+), 49 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d1bee3d2c033c..3f5d06e1aeee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.concurrent.ExecutionContext.Implicits.global +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -112,6 +113,11 @@ private[spark] class BlockManager( private val broadcastCleaner = new MetadataCleaner( MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) + // Field related to peer block managers that are necessary for block replication + @volatile private var cachedPeers: Seq[BlockManagerId] = _ + private val peerFetchLock = new Object + private var lastPeerFetchTime = 0L + initialize() /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay @@ -787,31 +793,111 @@ private[spark] class BlockManager( } /** - * Replicate block to another node. + * Get peer block managers in the system. + */ + private def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = { + peerFetchLock.synchronized { + val cachedPeersTtl = conf.getInt("spark.storage.cachedPeersTtl", 60 * 1000) // milliseconds + val timeout = System.currentTimeMillis - lastPeerFetchTime > cachedPeersTtl + if (cachedPeers == null || forceFetch || timeout) { + cachedPeers = master.getPeers(blockManagerId).sortBy(_.hashCode) + lastPeerFetchTime = System.currentTimeMillis + logDebug("Fetched peers from master: " + cachedPeers.mkString("[", ",", "]")) + } + cachedPeers + } + } + + /** + * Replicate block to another node. Not that this is a blocking call that returns after + * the block has been replicated. */ - @volatile var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = { + val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) + val numPeersToReplicateTo = level.replication - 1 + val peersForReplication = new ArrayBuffer[BlockManagerId] + val peersReplicatedTo = new ArrayBuffer[BlockManagerId] + val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] val tLevel = StorageLevel( level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) - if (cachedPeers == null) { - cachedPeers = master.getPeers(blockManagerId, level.replication - 1) + val startTime = System.currentTimeMillis + val random = new Random(blockId.hashCode) + + var replicationFailed = false + var failures = 0 + var done = false + + // Get cached list of peers + peersForReplication ++= getPeers(forceFetch = false) + + // Get a random peer. Note that this selection of a peer is deterministic on the block id. + // So assuming the list of peers does not change and no replication failures, + // if there are multiple attempts in the same node to replicate the same block, + // the same set of peers will be selected. + def getRandomPeer(): Option[BlockManagerId] = { + // If replication had failed, then force update the cached list of peers and remove the peers + // that have been already used + if (replicationFailed) { + peersForReplication.clear() + peersForReplication ++= getPeers(forceFetch = true) + peersForReplication --= peersReplicatedTo + peersForReplication --= peersFailedToReplicateTo + } + if (!peersForReplication.isEmpty) { + Some(peersForReplication(random.nextInt(peersForReplication.size))) + } else { + None + } } - for (peer: BlockManagerId <- cachedPeers) { - val start = System.nanoTime - data.rewind() - logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + - s"To node: $peer") - try { - blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) - } catch { - case e: Exception => - logError(s"Failed to replicate block to $peer", e) + // One by one choose a random peer and try uploading the block to it + // If replication fails (e.g., target peer is down), force the list of cached peers + // to be re-fetched from driver and then pick another random peer for replication. Also + // temporarily black list the peer for which replication failed. + // + // This selection of a peer and replication is continued in a loop until one of the + // following 3 conditions is fulfilled: + // (i) specified number of peers have been replicated to + // (ii) too many failures in replicating to peers + // (iii) no peer left to replicate to + // + while (!done) { + getRandomPeer() match { + case Some(peer) => + try { + val onePeerStartTime = System.currentTimeMillis + data.rewind() + logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") + blockTransferService.uploadBlockSync( + peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" + .format((System.currentTimeMillis - onePeerStartTime))) + peersReplicatedTo += peer + peersForReplication -= peer + replicationFailed = false + if (peersReplicatedTo.size == numPeersToReplicateTo) { + done = true // specified number of peers have been replicated to + } + } catch { + case e: Exception => + logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) + failures += 1 + replicationFailed = true + peersFailedToReplicateTo += peer + if (failures > maxReplicationFailures) { // too many failures in replcating to peers + done = true + } + } + case None => // no peer left to replicate to + done = true } - - logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." - .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) + } + val timeTakeMs = (System.currentTimeMillis - startTime) + logDebug(s"Replicating $blockId of ${data.limit()} bytes to " + + s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms") + if (peersReplicatedTo.size < numPeersToReplicateTo) { + logWarning(s"Block $blockId replicated to only " + + s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d4487fce49ab6..142285094342c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -59,6 +59,8 @@ class BlockManagerId private ( def port: Int = port_ + def isDriver: Boolean = (executorId == "") + override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 2e262594b3538..d08e1419e3e41 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,13 +84,8 @@ class BlockManagerMaster( } /** Get ids of other nodes in the cluster from the driver */ - def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) - if (result.length != numPeers) { - throw new SparkException( - "Error getting peers, only got " + result.size + " instead of " + numPeers) - } - result + def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 1a6c7cb24f9ac..6a06257ed0c08 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -83,8 +83,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetLocationsMultipleBlockIds(blockIds) => sender ! getLocationsMultipleBlockIds(blockIds) - case GetPeers(blockManagerId, size) => - sender ! getPeers(blockManagerId, size) + case GetPeers(blockManagerId) => + sender ! getPeers(blockManagerId) case GetMemoryStatus => sender ! memoryStatus @@ -173,11 +173,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - // TODO: Consolidate usages of import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => - removeFromDriver || info.blockManagerId.executorId != "" + removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => @@ -212,7 +211,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { + if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -232,7 +231,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus */ private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.executorId == "" && !isLocal + blockManagerId.isDriver && !isLocal } else { blockManagerInfo(blockManagerId).updateLastSeenMs() true @@ -355,7 +354,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus tachyonSize: Long) { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "" && !isLocal) { + if (blockManagerId.isDriver && !isLocal) { // We intentionally do not register the master (except in local mode), // so we should not indicate failure. sender ! true @@ -403,16 +402,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockIds.map(blockId => getLocations(blockId)) } - private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { - val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - - val selfIndex = peers.indexOf(blockManagerId) - if (selfIndex == -1) { - throw new SparkException("Self index for " + blockManagerId + " not found") + /** Get the list of the peers of the given block manager */ + private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + val blockManagerIds = blockManagerInfo.keySet + if (blockManagerIds.contains(blockManagerId)) { + blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + } else { + Seq.empty } - - // Note that this logic will select the same node multiple times if there aren't enough peers - Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2ba16b8476600..3db5dd9774ae8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -88,7 +88,7 @@ private[spark] object BlockManagerMessages { case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster - case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster + case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 978a6ded80829..acaf321de52fb 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -132,7 +132,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") + assert(bm.isDriver, "Block should only be on the driver") assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) assert(status.memSize > 0, "Block should be in memory store on the driver") assert(status.diskSize === 0, "Block should not be in disk store on the driver") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala new file mode 100644 index 0000000000000..1f1d53a1ee3b0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Props} +import org.mockito.Mockito.{mock, when} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.{AkkaUtils, SizeEstimator} + +/** Testsuite that tests block replication in BlockManager */ +class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { + + private val conf = new SparkConf(false) + var actorSystem: ActorSystem = null + var master: BlockManagerMaster = null + val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) + + // List of block manager created during an unit test, so that all of the them can be stopped + // after the unit test. + val allStores = new ArrayBuffer[BlockManager] + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + conf.set("spark.kryoserializer.buffer.mb", "1") + val serializer = new KryoSerializer(conf) + + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + + private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer) + allStores += store + store + } + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) + this.actorSystem = actorSystem + + conf.set("spark.authenticate", "false") + conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.storage.unrollFraction", "0.4") + conf.set("spark.storage.unrollMemoryThreshold", "512") + + // to make a replication attempt to inactive store fail fast + conf.set("spark.core.connection.ack.wait.timeout", "1") + // to make cached peers refresh frequently + conf.set("spark.storage.cachedPeersTtl", "10") + + master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf, true) + allStores.clear() + } + + after { + allStores.foreach { _.stop() } + allStores.clear() + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + master = null + } + + + test("get peers with addition and removal of block managers") { + val numStores = 4 + val stores = (1 to numStores - 1).map { i => makeBlockManager(1000, s"store$i") } + val storeIds = stores.map { _.blockManagerId }.toSet + assert(master.getPeers(stores(0).blockManagerId).toSet === + storeIds.filterNot { _ == stores(0).blockManagerId }) + assert(master.getPeers(stores(1).blockManagerId).toSet === + storeIds.filterNot { _ == stores(1).blockManagerId }) + assert(master.getPeers(stores(2).blockManagerId).toSet === + storeIds.filterNot { _ == stores(2).blockManagerId }) + + // Add driver store and test whether it is filtered out + val driverStore = makeBlockManager(1000, "") + assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver)) + assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver)) + assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver)) + + // Add a new store and test whether get peers returns it + val newStore = makeBlockManager(1000, s"store$numStores") + assert(master.getPeers(stores(0).blockManagerId).toSet === + storeIds.filterNot { _ == stores(0).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(stores(1).blockManagerId).toSet === + storeIds.filterNot { _ == stores(1).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(stores(2).blockManagerId).toSet === + storeIds.filterNot { _ == stores(2).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(newStore.blockManagerId).toSet === storeIds) + + // Remove a store and test whether get peers returns it + val storeIdToRemove = stores(0).blockManagerId + master.removeExecutor(storeIdToRemove.executorId) + assert(!master.getPeers(stores(1).blockManagerId).contains(storeIdToRemove)) + assert(!master.getPeers(stores(2).blockManagerId).contains(storeIdToRemove)) + assert(!master.getPeers(newStore.blockManagerId).contains(storeIdToRemove)) + + // Test whether asking for peers of a unregistered block manager id returns empty list + assert(master.getPeers(stores(0).blockManagerId).isEmpty) + assert(master.getPeers(BlockManagerId("", "", 1)).isEmpty) + } + + + test("block replication - 2x replication") { + testReplication(2, + Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER_2) + ) + } + + test("block replication - 3x replication") { + // Generate storage levels with 3x replication + val storageLevels = { + Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK, MEMORY_AND_DISK_SER).map { + level => StorageLevel( + level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 3) + } + } + testReplication(3, storageLevels) + } + + test("block replication - mixed between 1x to 5x") { + // Generate storage levels with varying replication + val storageLevels = Seq( + MEMORY_ONLY, + MEMORY_ONLY_SER_2, + StorageLevel(true, false, false, false, 3), + StorageLevel(true, true, false, true, 4), + StorageLevel(true, true, false, false, 5), + StorageLevel(true, true, false, true, 4), + StorageLevel(true, false, false, false, 3), + MEMORY_ONLY_SER_2, + MEMORY_ONLY + ) + testReplication(5, storageLevels) + } + + test("block replication - 2x replication without peers") { + intercept[org.scalatest.exceptions.TestFailedException] { + testReplication(1, + Seq(StorageLevel.MEMORY_AND_DISK_2, StorageLevel(true, false, false, false, 3))) + } + } + + test("block replication - deterministic node selection") { + val blockSize = 1000 + val storeSize = 10000 + val stores = (1 to 5).map { + i => makeBlockManager(storeSize, s"store$i") + } + val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 + val storageLevel3x = StorageLevel(true, true, false, true, 3) + val storageLevel4x = StorageLevel(true, true, false, true, 4) + + def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { + stores.head.putSingle(blockId, new Array[Byte](blockSize), level) + val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet + stores.foreach { _.removeBlock(blockId) } + master.removeBlock(blockId) + locations + } + + // Test if two attempts to 2x replication returns same set of locations + val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) + assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, + "Inserting a 2x replicated block second time gave different locations from the first") + + // Test if two attempts to 3x replication returns same set of locations + val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) + assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, + "Inserting a 3x replicated block second time gave different locations from the first") + + // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication + val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) + assert( + a2Locs2x.subsetOf(a2Locs3x), + "Inserting a with 2x replication gave locations that are not a subset of locations" + + s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" + ) + + // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication + val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) + assert( + a2Locs3x.subsetOf(a2Locs4x), + "Inserting a with 4x replication gave locations that are not a superset of locations " + + s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" + ) + + // Test if 3x replication of two different blocks gives two different sets of locations + val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) + assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") + } + + test("block replication - replication failures") { + /* + Create a system of three block managers / stores. One of them (say, failableStore) + cannot receive blocks. So attempts to use that as replication target fails. + + +-----------/fails/-----------> failableStore + | + normalStore + | + +-----------/works/-----------> anotherNormalStore + + We are first going to add a normal block manager (i.e. normalStore) and the failable block + manager (i.e. failableStore), and test whether 2x replication fails to create two + copies of a block. Then we are going to add another normal block manager + (i.e., anotherNormalStore), and test that now 2x replication works as the + new store will be used for replication. + */ + + // Add a normal block manager + val store = makeBlockManager(10000, "store") + + // Insert a block with 2x replication and return the number of copies of the block + def replicateAndGetNumCopies(blockId: String): Int = { + store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2) + val numLocations = master.getLocations(blockId).size + allStores.foreach { _.removeBlock(blockId) } + numLocations + } + + // Add a failable block manager with a mock transfer service that does not + // allow receiving of blocks. So attempts to use it as a replication target will fail. + val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work + when(failableTransfer.hostName).thenReturn("some-hostname") + when(failableTransfer.port).thenReturn(1000) + val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + allStores += failableStore // so that this gets stopped after test + assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) + + // Test that 2x replication fails by creating only one copy of the block + assert(replicateAndGetNumCopies("a1") === 1) + + // Add another normal block manager and test that 2x replication works + makeBlockManager(10000, "anotherStore") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a2") === 2) + } + } + + test("block replication - addition and deletion of block managers") { + val blockSize = 1000 + val storeSize = 10000 + val initialStores = (1 to 2).map { i => makeBlockManager(storeSize, s"store$i") } + + // Insert a block with given replication factor and return the number of copies of the block\ + def replicateAndGetNumCopies(blockId: String, replicationFactor: Int): Int = { + val storageLevel = StorageLevel(true, true, false, true, replicationFactor) + initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val numLocations = master.getLocations(blockId).size + allStores.foreach { _.removeBlock(blockId) } + numLocations + } + + // 2x replication should work, 3x replication should only replicate 2x + assert(replicateAndGetNumCopies("a1", 2) === 2) + assert(replicateAndGetNumCopies("a2", 3) === 2) + + // Add another store, 3x replication should work now, 4x replication should only replicate 3x + val newStore1 = makeBlockManager(storeSize, s"newstore1") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a3", 3) === 3) + } + assert(replicateAndGetNumCopies("a4", 4) === 3) + + // Add another store, 4x replication should work now + val newStore2 = makeBlockManager(storeSize, s"newstore2") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a5", 4) === 4) + } + + // Remove all but the 1st store, 2x replication should fail + (initialStores.tail ++ Seq(newStore1, newStore2)).foreach { + store => + master.removeExecutor(store.blockManagerId.executorId) + store.stop() + } + assert(replicateAndGetNumCopies("a6", 2) === 1) + + // Add new stores, 3x replication should work + val newStores = (3 to 5).map { + i => makeBlockManager(storeSize, s"newstore$i") + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a7", 3) === 3) + } + } + + /** + * Test replication of blocks with different storage levels (various combinations of + * memory, disk & serialization). For each storage level, this function tests every store + * whether the block is present and also tests the master whether its knowledge of blocks + * is correct. Then it also drops the block from memory of each store (using LRU) and + * again checks whether the master's knowledge gets updated. + */ + private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { + import org.apache.spark.storage.StorageLevel._ + + assert(maxReplication > 1, + s"Cannot test replication factor $maxReplication") + + // storage levels to test with the given replication factor + + val storeSize = 10000 + val blockSize = 1000 + + // As many stores as the replication factor + val stores = (1 to maxReplication).map { + i => makeBlockManager(storeSize, s"store$i") + } + + storageLevels.foreach { storageLevel => + // Put the block into one of the stores + val blockId = new TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) + stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + + // Assert that master know two locations for the block + val blockLocations = master.getLocations(blockId).map(_.executorId).toSet + assert(blockLocations.size === storageLevel.replication, + s"master did not have ${storageLevel.replication} locations for $blockId") + + // Test state of the stores that contain the block + stores.filter { + testStore => blockLocations.contains(testStore.blockManagerId.executorId) + }.foreach { testStore => + val testStoreName = testStore.blockManagerId.executorId + assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName") + assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), + s"master does not have status for ${blockId.name} in $testStoreName") + + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) + + // Assert that block status in the master for this store has expected storage level + assert( + blockStatus.storageLevel.useDisk === storageLevel.useDisk && + blockStatus.storageLevel.useMemory === storageLevel.useMemory && + blockStatus.storageLevel.useOffHeap === storageLevel.useOffHeap && + blockStatus.storageLevel.deserialized === storageLevel.deserialized, + s"master does not know correct storage level for ${blockId.name} in $testStoreName") + + // Assert that the block status in the master for this store has correct memory usage info + assert(!blockStatus.storageLevel.useMemory || blockStatus.memSize >= blockSize, + s"master does not know size of ${blockId.name} stored in memory of $testStoreName") + + + // If the block is supposed to be in memory, then drop the copy of the block in + // this store test whether master is updated with zero memory usage this store + if (storageLevel.useMemory) { + // Force the block to be dropped by adding a number of dummy blocks + (1 to 10).foreach { + i => + testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + } + (1 to 10).foreach { + i => testStore.removeBlock(s"dummy-block-$i") + } + + val newBlockStatusOption = master.getBlockStatus(blockId).get(testStore.blockManagerId) + + // Assert that the block status in the master either does not exist (block removed + // from every store) or has zero memory usage for this store + assert( + newBlockStatusOption.isEmpty || newBlockStatusOption.get.memSize === 0, + s"after dropping, master does not know size of ${blockId.name} " + + s"stored in memory of $testStoreName" + ) + } + + // If the block is supposed to be in disk (after dropping or otherwise, then + // test whether master has correct disk usage for this store + if (storageLevel.useDisk) { + assert(master.getBlockStatus(blockId)(testStore.blockManagerId).diskSize >= blockSize, + s"after dropping, master does not know size of ${blockId.name} " + + s"stored in disk of $testStoreName" + ) + } + } + master.removeBlock(blockId) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e251660dae5de..9d96202a3e7ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,8 +21,6 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit -import org.apache.spark.network.nio.NioBlockTransferService - import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await import scala.concurrent.duration._ @@ -35,13 +33,13 @@ import akka.util.Timeout import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -189,7 +187,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter store = makeBlockManager(2000, "exec1") store2 = makeBlockManager(2000, "exec2") - val peers = master.getPeers(store.blockManagerId, 1) + val peers = master.getPeers(store.blockManagerId) assert(peers.size === 1, "master did not return the other manager as a peer") assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") @@ -448,7 +446,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val list2DiskGet = store.get("list2disk") assert(list2DiskGet.isDefined, "list2memory expected to be in store") assert(list2DiskGet.get.data.size === 3) - System.out.println(list2DiskGet) // We don't know the exact size of the data on disk, but it should certainly be > 0. assert(list2DiskGet.get.inputMetrics.bytesRead > 0) assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk) From 127e97bee1e6aae7b70263bc5944b7be6f4e6fea Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 2 Oct 2014 13:52:54 -0700 Subject: [PATCH 172/315] [SPARK-3632] ConnectionManager can run out of receive threads with authentication on If you turn authentication on and you are using a lot of executors. There is a chance that all the of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out. To fix it, I got rid of the wait/notify and use a single outbox but only send security messages from it until authentication has completed. Author: Thomas Graves Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits: a0a961d [Thomas Graves] give it a type b6bc80b [Thomas Graves] Rework comments d6d4175 [Thomas Graves] update from comments 081b765 [Thomas Graves] cleanup 4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication --- .../org/apache/spark/SecurityManager.scala | 5 +- .../apache/spark/network/nio/Connection.scala | 65 +++++++++++------ .../spark/network/nio/ConnectionManager.scala | 72 +++++-------------- 3 files changed, 63 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 3832a780ec4bc..0e0f1a7b2377e 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to * match up incoming messages with connections waiting for authentication. - * If its acting as a client and trying to send a message to another ConnectionManager, - * it blocks the thread calling sendMessage until the SASL negotiation has occurred. * The ConnectionManager tracks all the sendingConnections using the ConnectionId - * and waits for the response from the server and does the handshake. + * and waits for the response from the server and does the handshake before sending + * the real message. * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 18172d359cb35..f368209980f93 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,23 +20,27 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.LinkedList import org.apache.spark._ -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap} private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, + val securityMgr: SecurityManager) extends Logging { var sparkSaslServer: SparkSaslServer = null var sparkSaslClient: SparkSaslClient = null - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, + securityMgr_ : SecurityManager) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_) + channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), + id_, securityMgr_) } channel.configureBlocking(false) @@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() - /** - * Used to synchronize client requests: client's work-related requests must - * wait until SASL authentication completes. - */ - private val authenticated = new Object() - - def getAuthenticated(): Object = authenticated - def isSaslComplete(): Boolean def resetForceReregister(): Boolean @@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[nio] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId) - extends Connection(SocketChannel.open, selector_, remoteId_, id_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId, + securityMgr_ : SecurityManager) + extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { def isSaslComplete(): Boolean = { if (sparkSaslClient != null) sparkSaslClient.isComplete() else false } private class Outbox { - val messages = new Queue[Message]() + val messages = new LinkedList[Message]() val defaultChunkSize = 65536 var nextMessageToBeUsed = 0 def addMessage(message: Message) { messages.synchronized { - /* messages += message */ - messages.enqueue(message) + messages.add(message) logDebug("Added [" + message + "] to outbox for sending to " + "[" + getRemoteConnectionManagerId() + "]") } @@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, while (!messages.isEmpty) { /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ /* val message = messages(nextMessageToBeUsed) */ - val message = messages.dequeue() + + val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { + // only allow sending of security messages until sasl is complete + var pos = 0 + var securityMsg: Message = null + while (pos < messages.size() && securityMsg == null) { + if (messages.get(pos).isSecurityNeg) { + securityMsg = messages.remove(pos) + } + pos = pos + 1 + } + // didn't find any security messages and auth isn't completed so return + if (securityMsg == null) return None + securityMsg + } else { + messages.removeFirst() + } + val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { - messages.enqueue(message) + messages.add(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { logDebug( @@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, changeConnectionKeyInterest(DEFAULT_INTEREST) } + def registerAfterAuth(): Unit = { + outbox.synchronized { + needForceReregister = true + } + if (channel.isConnected) { + registerInterest() + } + } + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) @@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, private[spark] class ReceivingConnection( channel_ : SocketChannel, selector_ : Selector, - id_ : ConnectionId) - extends Connection(channel_, selector_, id_) { + id_ : ConnectionId, + securityMgr_ : SecurityManager) + extends Connection(channel_, selector_, id_, securityMgr_) { def isSaslComplete(): Boolean = { if (sparkSaslServer != null) sparkSaslServer.isComplete() else false diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 5aa7e94943561..01cd27a907eea 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import org.apache.spark._ -import org.apache.spark.util.{SystemClock, Utils} +import org.apache.spark.util.Utils private[nio] class ConnectionManager( @@ -65,8 +65,6 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) - // default to 30 second timeout waiting for authentication - private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( @@ -409,7 +407,8 @@ private[nio] class ConnectionManager( while (newChannel != null) { try { val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, + securityManager) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -527,9 +526,8 @@ private[nio] class ConnectionManager( if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll() - } + waitingConn.registerAfterAuth() + wakeupSelector() return } else { var replyToken : Array[Byte] = null @@ -538,9 +536,8 @@ private[nio] class ConnectionManager( if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll() - } + waitingConn.registerAfterAuth() + wakeupSelector() return } val securityMsgResp = SecurityMessage.fromResponse(replyToken, @@ -574,9 +571,11 @@ private[nio] class ConnectionManager( } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId) + logDebug("Server sasl completed: " + connection.connectionId + + " for: " + connectionId) } else { - logDebug("Server sasl not completed: " + connection.connectionId) + logDebug("Server sasl not completed: " + connection.connectionId + + " for: " + connectionId) } if (replyToken != null) { val securityMsgResp = SecurityMessage.fromResponse(replyToken, @@ -723,7 +722,8 @@ private[nio] class ConnectionManager( if (message == null) throw new Exception("Error creating security message") connectionsAwaitingSasl += ((conn.connectionId, conn)) sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + + " to: " + connManagerId) } catch { case e: Exception => { logError("Error getting first response from the SaslClient.", e) @@ -744,7 +744,7 @@ private[nio] class ConnectionManager( val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId) + newConnectionId, securityManager) logInfo("creating new sending connection for security! " + newConnectionId ) registerRequests.enqueue(newConnection) @@ -769,61 +769,23 @@ private[nio] class ConnectionManager( connectionManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId) + newConnectionId, securityManager) logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection } val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) - } + message.senderAddress = id.toSocketAddress() logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + "connectionid: " + connection.connectionId) if (authEnabled) { - // if we aren't authenticated yet lets block the senders until authentication completes - try { - connection.getAuthenticated().synchronized { - val clock = SystemClock - val startTime = clock.getTime() - - while (!connection.isSaslComplete()) { - logDebug("getAuthenticated wait connectionid: " + connection.connectionId) - // have timeout in case remote side never responds - connection.getAuthenticated().wait(500) - if (((clock.getTime() - startTime) >= (authTimeout * 1000)) - && (!connection.isSaslComplete())) { - // took to long to authenticate the connection, something probably went wrong - throw new Exception("Took to long for authentication to " + connectionManagerId + - ", waited " + authTimeout + "seconds, failing.") - } - } - } - } catch { - case e: Exception => logError("Exception while waiting for authentication.", e) - - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(message.id) - s match { - case Some(msgStatus) => { - messageStatuses -= message.id - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.markDone(None) - } - case None => { - logError("no messageStatus for failed message id: " + message.id) - } - } - } - } + checkSendAuthFirst(connectionManagerId, connection) } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) - wakeupSelector() } From 8081ce8bd111923db143abc55bb6ef9793eece35 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 2 Oct 2014 17:47:56 -0700 Subject: [PATCH 173/315] [SPARK-3755][Core] avoid trying privileged port when request a non-privileged port pwendell, ```tryPort``` is not compatible with old code in last PR, this is to fix it. And after discuss with srowen renamed the title to "avoid trying privileged port when request a non-privileged port". Plz refer to the discuss for detail. Author: scwf Closes #2623 from scwf/1-1024 and squashes the following commits: 10a4437 [scwf] add comment de3fd17 [scwf] do not try privileged port when request a non-privileged port 42cb0fa [scwf] make tryPort compatible with old code cb8cc76 [scwf] do not use port 1 - 1024 --- core/src/main/scala/org/apache/spark/util/Utils.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b3025c6ec3364..9399ddab76331 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1439,7 +1439,12 @@ private[spark] object Utils extends Logging { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port - val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 + val tryPort = if (startPort == 0) { + startPort + } else { + // If the new port wraps around, do not try a privilege port + ((startPort + offset - 1024) % (65536 - 1024)) + 1024 + } try { val (service, port) = startService(tryPort) logInfo(s"Successfully started service$serviceString on port $port.") From 42d5077fd3f2c37d1cd23f4c81aa89286a74cb40 Mon Sep 17 00:00:00 2001 From: Eric Eijkelenboom Date: Thu, 2 Oct 2014 18:04:38 -0700 Subject: [PATCH 174/315] [DEPLOY] SPARK-3759: Return the exit code of the driver process SparkSubmitDriverBootstrapper.scala now returns the exit code of the driver process, instead of always returning 0. Author: Eric Eijkelenboom Closes #2628 from ericeijkelenboom/master and squashes the following commits: cc4a571 [Eric Eijkelenboom] Return the exit code of the driver process --- .../apache/spark/deploy/SparkSubmitDriverBootstrapper.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 38b5d8e1739d0..a64170a47bc1c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -154,7 +154,8 @@ private[spark] object SparkSubmitDriverBootstrapper { process.destroy() } } - process.waitFor() + val returnCode = process.waitFor() + sys.exit(returnCode) } } From 7de4e50a01e90bcf88e0b721b2b15a5162373d56 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 2 Oct 2014 19:32:21 -0700 Subject: [PATCH 175/315] [SQL] Initilize session state before creating CommandProcessor With the old ordering it was possible for commands in the HiveDriver to NPE due to the lack of configuration in the threadlocal session state. Author: Michael Armbrust Closes #2635 from marmbrus/initOrder and squashes the following commits: 9749850 [Michael Armbrust] Initilize session state before creating CommandProcessor --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fdb56901f9ddb..8bcc098bbb620 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -281,13 +281,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { try { + // Session state must be initilized before the CommandProcessor is created . + SessionState.start(sessionState) + val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf) - SessionState.start(sessionState) - proc match { case driver: Driver => driver.init() From 1c90347a4bba12df7b76d282a7dbac8e555e049f Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 2 Oct 2014 20:04:33 -0700 Subject: [PATCH 176/315] [SPARK-3654][SQL] Implement all extended HiveQL statements/commands with a separate parser combinator Created separate parser for hql. It preparses the commands like cache,uncache,add jar etc.. and then parses with HiveQl Author: ravipesala Closes #2590 from ravipesala/SPARK-3654 and squashes the following commits: bbca7dd [ravipesala] Fixed code as per admin comments. ae9290a [ravipesala] Fixed style issues as per Admin comments 898ed81 [ravipesala] Removed spaces fb24edf [ravipesala] Updated the code as per admin comments 8947d37 [ravipesala] Removed duplicate code ba26cd1 [ravipesala] Created seperate parser for hql.It pre parses the commands like cache,uncache,add jar etc.. and then parses with HiveQl --- .../spark/sql/hive/ExtendedHiveQlParser.scala | 135 ++++++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 57 ++------ .../spark/sql/hive/CachedTableSuite.scala | 6 + 3 files changed, 154 insertions(+), 44 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala new file mode 100644 index 0000000000000..e7e1cb980c2ae --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.language.implicitConversions +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser that recognizes all HiveQL constructs together with several Spark SQL specific + * extensions like CACHE TABLE and UNCACHE TABLE. + */ +private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = { + // Special-case out set commands since the value fields can be + // complex to handle without RegexParsers. Also this approach + // is clearer for the several possible cases of set commands. + if (input.trim.toLowerCase.startsWith("set")) { + input.trim.drop(3).split("=", 2).map(_.trim) match { + case Array("") => // "set" + SetCommand(None, None) + case Array(key) => // "set key" + SetCommand(Some(key), None) + case Array(key, value) => // "set key=value" + SetCommand(Some(key), Some(value)) + } + } else if (input.trim.startsWith("!")) { + ShellCommand(input.drop(1)) + } else { + phrase(query)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => sys.error(x.toString) + } + } + } + + protected case class Keyword(str: String) + + protected val CACHE = Keyword("CACHE") + protected val SET = Keyword("SET") + protected val ADD = Keyword("ADD") + protected val JAR = Keyword("JAR") + protected val TABLE = Keyword("TABLE") + protected val AS = Keyword("AS") + protected val UNCACHE = Keyword("UNCACHE") + protected val FILE = Keyword("FILE") + protected val DFS = Keyword("DFS") + protected val SOURCE = Keyword("SOURCE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected def allCaseConverse(k: String): Parser[String] = + lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _) + + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val query: Parser[LogicalPlan] = + cache | uncache | addJar | addFile | dfs | source | hiveQl + + protected lazy val hiveQl: Parser[LogicalPlan] = + remainingQuery ^^ { + case r => HiveQl.createPlan(r.trim()) + } + + /** It returns all remaining query */ + protected lazy val remainingQuery: Parser[String] = new Parser[String] { + def apply(in: Input) = + Success( + in.source.subSequence(in.offset, in.source.length).toString, + in.drop(in.source.length())) + } + + /** It returns all query */ + protected lazy val allQuery: Parser[String] = new Parser[String] { + def apply(in: Input) = + Success(in.source.toString, in.drop(in.source.length())) + } + + protected lazy val cache: Parser[LogicalPlan] = + CACHE ~ TABLE ~> ident ~ opt(AS ~> hiveQl) ^^ { + case tableName ~ None => CacheCommand(tableName, true) + case tableName ~ Some(plan) => + CacheTableAsSelectCommand(tableName, plan) + } + + protected lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => CacheCommand(tableName, false) + } + + protected lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> remainingQuery ^^ { + case rq => AddJar(rq.trim()) + } + + protected lazy val addFile: Parser[LogicalPlan] = + ADD ~ FILE ~> remainingQuery ^^ { + case rq => AddFile(rq.trim()) + } + + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> allQuery ^^ { + case aq => NativeCommand(aq.trim()) + } + + protected lazy val source: Parser[LogicalPlan] = + SOURCE ~> remainingQuery ^^ { + case rq => SourceCommand(rq.trim()) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4f3f808c93dc8..6bb42eeb0550d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -126,6 +126,9 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands + + // It parses hive sql query along with with several Spark SQL specific extensions + protected val hiveSqlParser = new ExtendedHiveQlParser /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -215,40 +218,19 @@ private[hive] object HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = { + def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + + /** Creates LogicalPlan for a given HiveQL string. */ + def createPlan(sql: String) = { try { - if (sql.trim.toLowerCase.startsWith("set")) { - // Split in two parts since we treat the part before the first "=" - // as key, and the part after as value, which may contain other "=" signs. - sql.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (sql.trim.toLowerCase.startsWith("cache table")) { - sql.trim.drop(12).trim.split(" ").toSeq match { - case Seq(tableName) => - CacheCommand(tableName, true) - case Seq(tableName, _, select @ _*) => - CacheTableAsSelectCommand(tableName, createPlan(select.mkString(" ").trim)) - } - } else if (sql.trim.toLowerCase.startsWith("uncache table")) { - CacheCommand(sql.trim.drop(14).trim, false) - } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8).trim) - } else if (sql.trim.toLowerCase.startsWith("add file")) { - AddFile(sql.trim.drop(9)) - } else if (sql.trim.toLowerCase.startsWith("dfs")) { + val tree = getAst(sql) + if (nativeCommands contains tree.getText) { NativeCommand(sql) - } else if (sql.trim.startsWith("source")) { - SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) - } else if (sql.trim.startsWith("!")) { - ShellCommand(sql.drop(1)) } else { - createPlan(sql) + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } } } catch { case e: Exception => throw new ParseException(sql, e) @@ -259,19 +241,6 @@ private[hive] object HiveQl { """.stripMargin) } } - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String) = { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { - NativeCommand(sql) - } else { - nodeToPlan(tree) match { - case NativePlaceholder => NativeCommand(sql) - case other => other - } - } - } def parseDdl(ddl: String): Seq[Attribute] = { val tree = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 188579edd7bdd..b3057cd618c66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -88,4 +88,10 @@ class CachedTableSuite extends HiveComparisonTest { } assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } + + test("'CACHE TABLE tableName AS SELECT ..'") { + TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestHive.uncacheTable("testCacheTable") + } } From 2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Fri, 3 Oct 2014 03:26:17 -0700 Subject: [PATCH 177/315] [SPARK-3366][MLLIB]Compute best splits distributively in decision tree Currently, all best splits are computed on the driver, which makes the driver a bottleneck for both communication and computation. This PR fix this problem by computed best splits on executors. Instead of send all aggregate stats to the driver node, we can send aggregate stats for a node to a particular executor, using `reduceByKey` operation, then we can compute best split for this node there. Implementation details: Each node now has a nodeStatsAggregator, which save aggregate stats for all features and bins. First use mapPartition to compute node aggregate stats for all nodes in each partition. Then transform node aggregate stats to (nodeIndex, nodeStatsAggregator) pairs and use to `reduceByKey` operation to combine nodeStatsAggregator for the same node. After all stats have been combined, best splits can be computed for each node based on the node aggregate stats. Best split result is collected to driver to construct the decision tree. CC: mengxr manishamde jkbradley, please help me review this, thanks. Author: qiping.lqp Author: chouqin Closes #2595 from chouqin/dt-dist-agg and squashes the following commits: db0d24a [chouqin] fix a minor bug and adjust code a0d9de3 [chouqin] adjust code based on comments 9f201a6 [chouqin] fix bug: statsSize -> allStatsSize a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg f13b346 [chouqin] adjust randomforest comments c32636e [chouqin] adjust code based on comments ac6a505 [chouqin] adjust code based on comments 7bbb787 [chouqin] add comments bdd2a63 [qiping.lqp] fix test suite a75df27 [qiping.lqp] fix test suite b5b0bc2 [qiping.lqp] fix style e76414f [qiping.lqp] fix testsuite 748bd45 [qiping.lqp] fix type-mismatch bug 24eacd8 [qiping.lqp] fix type-mismatch bug 5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy 4f56496 [qiping.lqp] fix bug f00fc22 [qiping.lqp] fix bug 532993a [qiping.lqp] Compute best splits distributively in decision tree --- .../spark/mllib/tree/DecisionTree.scala | 140 ++++++--- .../spark/mllib/tree/RandomForest.scala | 5 +- .../mllib/tree/impl/DTStatsAggregator.scala | 292 +++++------------- .../tree/model/InformationGainStats.scala | 11 + .../spark/mllib/tree/RandomForestSuite.scala | 1 + 5 files changed, 182 insertions(+), 267 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b7dc373ebd9cc..b311d10023894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -23,7 +23,6 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.SparkContext._ /** @@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging { * for each subset is updated. * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). + * each (feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. * @param instanceWeight Weight (importance) of instance in dataset. @@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging { private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int], instanceWeight: Double, @@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging { // Use all features agg.metadata.numFeatures } - val nodeOffset = agg.getNodeOffset(nodeIndex) // Iterate over features. var featureIndexIdx = 0 while (featureIndexIdx < numFeaturesPerNode) { @@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) + agg.getLeftRightFeatureOffsets(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, + agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 @@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging { } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label, - instanceWeight) + agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) } featureIndexIdx += 1 } @@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging { * For each feature, the sufficient statistics of one bin are updated. * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). + * each (feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int, instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label - val nodeOffset = agg.getNodeOffset(nodeIndex) + // Iterate over features. if (featuresForNode.nonEmpty) { // Use subsampled features var featureIndexIdx = 0 while (featureIndexIdx < featuresForNode.get.size) { val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) featureIndexIdx += 1 } } else { @@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight) + agg.update(featureIndex, binIndex, label, instanceWeight) featureIndex += 1 } } @@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def binSeqOp( - agg: DTStatsAggregator, - baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = { + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) @@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging { val featuresForNode = nodeInfo.featureSubset val instanceWeight = baggedPoint.subsampleWeights(treeIndex) if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode) + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { - mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures, + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, instanceWeight, featuresForNode) } } @@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging { agg } - // Calculate bin aggregates. - timer.start("aggregation") - val binAggregates: DTStatsAggregator = { - val initAgg = if (metadata.subsamplingFeatures) { - new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo) - } else { - new DTStatsAggregatorFixedFeatures(metadata, numNodes) + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group + * @param treeToNodeToIndexInfo + * @return + */ + def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) + : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } } - input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) + Some(mutableNodeToFeatures.toMap) } - timer.stop("aggregation") // Calculate best splits for all nodes in the group timer.start("chooseSplits") + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield an (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. + val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + val nodeToBestSplits = + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + }.reduceByKey((a, b) => a.merge(b)) + .map { case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + // find best split for each node + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(aggStats, splits, featuresForNode) + (nodeIndex, (split, stats, predict)) + }.collectAsMap() + + timer.stop("chooseSplits") + // Iterate over all nodes in this group. nodesForGroup.foreach { case (treeIndex, nodesForTree) => nodesForTree.foreach { node => val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode) + nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. @@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging { } } } - timer.stop("chooseSplits") + } /** @@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. * @param binAggregates Bin statistics. - * @param nodeIndex Index into aggregates for node to split in this group. * @return tuple for best split: (Split, information gain, prediction at node) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, - nodeIndex: Int, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { - val metadata: DecisionTreeMetadata = binAggregates.metadata - // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => + val (bestSplit, bestSplitStats) = + Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => val featureIndex = if (featuresForNode.nonEmpty) { featuresForNode.get.apply(featureIndexIdx) } else { featureIndexIdx } - val numSplits = metadata.numSplits(featureIndex) - if (metadata.isContinuous(featureIndex)) { + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) var splitIndex = 0 while (splitIndex < numSplits) { - binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } // Find best split. @@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (metadata.isUnordered(featureIndex)) { + } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) - val numBins = metadata.numBins(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numBins = binAggregates.metadata.numBins(featureIndex) /* Each bin is one category (feature value). * The bins are ordered based on centroidForCategories, and this ordering determines which @@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (metadata.isMulticlass) { + val centroidForCategories = if (binAggregates.metadata.isMulticlass) { // For categorical variables in multiclass classification, // the bins are ordered by the impurity of their corresponding labels. Range(0, numBins).map { case featureValue => @@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging { while (splitIndex < numSplits) { val currentCategory = categoriesSortedByCentroid(splitIndex)._1 val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) splitIndex += 1 } // lastCategory = index of bin with total aggregates for this (node, feature) @@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 7fa7725e79e46..fa7a26f17c3ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -171,8 +171,8 @@ private class RandomForest ( // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") - DecisionTree.findBestSplits(baggedInput, - metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, + treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) timer.stop("findBestSplits") } @@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging { * @param maxMemoryUsage Bound on size of aggregate statistics. * @return (nodesForGroup, treeToNodeToIndexInfo). * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * * treeToNodeToIndexInfo holds indices selected features for each node: * treeIndex --> (global) node index --> (node index in group, feature indices). * The (global) node index is the index in the tree; the node index in group is the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index d49df7a016375..55f422dff0d71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,17 +17,19 @@ package org.apache.spark.mllib.tree.impl -import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.impurity._ + + /** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) + * DecisionTree statistics aggregator for a node. + * This holds a flat array of statistics for a set of (features, bins) * and helps with indexing. * This class is abstract to support learning with and without feature subsampling. */ -private[tree] abstract class DTStatsAggregator( - val metadata: DecisionTreeMetadata) extends Serializable { +private[tree] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]) extends Serializable { /** * [[ImpurityAggregator]] instance specifying the impurity type. @@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator( /** * Number of elements (Double values) used for the sufficient statistics of each bin. */ - val statsSize: Int = impurityAggregator.statsSize + private val statsSize: Int = impurityAggregator.statsSize + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + private val numBins: Array[Int] = { + if (featureSubset.isDefined) { + featureSubset.get.map(metadata.numBins(_)) + } else { + metadata.numBins + } + } + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } /** * Indicator for each feature of whether that feature is an unordered feature. @@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator( def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) /** - * Total number of elements stored in this aggregator. + * Total number of elements stored in this aggregator */ - def allStatsSize: Int + private val allStatsSize: Int = featureOffsets.last /** - * Get flat array of elements stored in this aggregator. + * Flat array of elements. + * Index for start of stats for a (feature, bin) is: + * index = featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, + * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) + * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) */ - protected def allStats: Array[Double] + private val allStats: Array[Double] = new Array[Double](allStatsSize) + /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. + * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. * For unordered features, this is a pre-computed * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. + * [[getLeftRightFeatureOffsets]]. */ - def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = { - impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize) + def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } /** - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * Update the stats for a given (feature, bin) for ordered features, using the given label. */ - def update( - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize + def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + val i = featureOffsets(featureIndex) + binIndex * statsSize impurityAggregator.update(allStats, i, label, instanceWeight) } - /** - * Pre-compute node offset for use with [[nodeUpdate]]. - */ - def getNodeOffset(nodeIndex: Int): Int - /** * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. - * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + * Update the stats for a given (feature, bin), using the given label. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. + * For unordered features, this is a pre-computed + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. */ - def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, + def featureUpdate( + featureOffset: Int, binIndex: Int, label: Double, - instanceWeight: Double): Unit + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, + label, instanceWeight) + } /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. */ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int + def getFeatureOffset(featureIndex: Int): Int = { + require(!isUnordered(featureIndex), + s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + + s" for unordered feature $featureIndex.") + featureOffsets(featureIndex) + } /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. */ - def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { + def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + + s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + s" but was called for ordered feature $featureIndex.") - val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex) - (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize) - } - - /** - * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin), using the given label. - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. - */ - def nodeFeatureUpdate( - nodeFeatureOffset: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label, - instanceWeight) + val baseOffset = featureOffsets(featureIndex) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } /** - * For a given (node, feature), merge the stats for two bins. - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. + * For a given feature, merge the stats for two bins. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. * @param binIndex The other bin is merged into this bin. * @param otherBinIndex This bin is not modified. */ - def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { - impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize, - nodeFeatureOffset + otherBinIndex * statsSize) + def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, + featureOffset + otherBinIndex * statsSize) } /** @@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator( def merge(other: DTStatsAggregator): DTStatsAggregator = { require(allStatsSize == other.allStatsSize, s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." - + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") var i = 0 // TODO: Test BLAS.axpy while (i < allStatsSize) { @@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator( this } } - -/** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) - * and helps with indexing. - * - * This instance of [[DTStatsAggregator]] is used when not subsampling features. - * - * @param numNodes Number of nodes to collect statistics for. - */ -private[tree] class DTStatsAggregatorFixedFeatures( - metadata: DecisionTreeMetadata, - numNodes: Int) extends DTStatsAggregator(metadata) { - - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - * Mapping: featureIndex --> offset - */ - private val featureOffsets: Array[Int] = { - metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. - */ - private val nodeStride: Int = featureOffsets.last - - override val allStatsSize: Int = numNodes * nodeStride - - /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats precede the right child stats - * in the binIndex order. - */ - override protected val allStats: Array[Double] = new Array[Double](allStatsSize) - - override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride - - override def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - nodeIndex * nodeStride + featureOffsets(featureIndex) - } -} - -/** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) - * and helps with indexing. - * - * This instance of [[DTStatsAggregator]] is used when subsampling features. - * - * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, - * where nodeIndexInfo stores the index in the group and the - * feature subsets (if using feature subsets). - */ -private[tree] class DTStatsAggregatorSubsampledFeatures( - metadata: DecisionTreeMetadata, - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) { - - /** - * For each node, offset for each feature for calculating indices into the [[allStats]] array. - * Mapping: nodeIndex --> featureIndex --> offset - */ - private val featureOffsets: Array[Array[Int]] = { - val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum - val offsets = new Array[Array[Int]](numNodes) - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) => - nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) => - offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_)) - .scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - } - offsets - } - - /** - * For each node, offset for each feature for calculating indices into the [[allStats]] array. - */ - protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _) - - override val allStatsSize: Int = nodeOffsets.last - - /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats precede the right child stats - * in the binIndex order. - */ - override protected val allStats: Array[Double] = new Array[Double](allStatsSize) - - override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex) - - /** - * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. - * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. - * @param featureIndex Index of feature in featuresForNodes(nodeIndex). - * Note: This is NOT the original feature index. - */ - override def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. - * For ordered features only. - * @param featureIndex Index of feature in featuresForNodes(nodeIndex). - * Note: This is NOT the original feature index. - */ - override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) - } -} - -private[tree] object DTStatsAggregator extends Serializable { - - /** - * Combines two aggregates (modifying the first) and returns the combination. - */ - def binCombOp( - agg1: DTStatsAggregator, - agg2: DTStatsAggregator): DTStatsAggregator = { - agg1.merge(agg2) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f3e2619bd8ba0..a89e71e115806 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -38,6 +38,17 @@ class InformationGainStats( "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" .format(gain, impurity, leftImpurity, rightImpurity) } + + override def equals(o: Any) = + o match { + case other: InformationGainStats => { + gain == other.gain && + impurity == other.impurity && + leftImpurity == other.leftImpurity && + rightImpurity == other.rightImpurity + } + case _ => false + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 30669fcd1c75b..20d372dc1d3ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { assert(nodesForGroup.size === numTrees, failString) assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree + if (numFeaturesPerNode == numFeatures) { // featureSubset values should all be None assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), From f0811f928e5b608e1a2cba3b6828ba0ed03b701d Mon Sep 17 00:00:00 2001 From: EugenCepoi Date: Fri, 3 Oct 2014 10:03:15 -0700 Subject: [PATCH 178/315] SPARK-2058: Overriding SPARK_HOME/conf with SPARK_CONF_DIR Update of PR #997. With this PR, setting SPARK_CONF_DIR overrides SPARK_HOME/conf (not only spark-defaults.conf and spark-env). Author: EugenCepoi Closes #2481 from EugenCepoi/SPARK-2058 and squashes the following commits: 0bb32c2 [EugenCepoi] use orElse orNull and fixing trailing percent in compute-classpath.cmd 77f35d7 [EugenCepoi] SPARK-2058: Overriding SPARK_HOME/conf with SPARK_CONF_DIR --- bin/compute-classpath.cmd | 8 +++- bin/compute-classpath.sh | 8 +++- .../spark/deploy/SparkSubmitArguments.scala | 42 ++++++++----------- .../spark/deploy/SparkSubmitSuite.scala | 34 ++++++++++++++- docs/configuration.md | 7 ++++ 5 files changed, 71 insertions(+), 28 deletions(-) diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 5ad52452a5c98..9b9e40321ea93 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -36,7 +36,13 @@ rem Load environment variables from conf\spark-env.cmd, if it exists if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%;%FWDIR%conf +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% + +if "x%SPARK_CONF_DIR%"!="x" ( + set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% +) else ( + set CLASSPATH=%CLASSPATH%;%FWDIR%conf +) if exist "%FWDIR%RELEASE" ( for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 0f63e36d8aeca..905bbaf99b374 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -27,8 +27,14 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" . "$FWDIR"/bin/load-spark-env.sh +CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH" + # Build up classpath -CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf" +if [ -n "$SPARK_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR" +else + CLASSPATH="$CLASSPATH:$FWDIR/conf" +fi ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 2b72c61cc8177..57b251ff47714 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,8 +29,9 @@ import org.apache.spark.util.Utils /** * Parses and encapsulates arguments from the spark-submit script. + * The env argument is used for testing. */ -private[spark] class SparkSubmitArguments(args: Seq[String]) { +private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -90,20 +91,12 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user if (propertiesFile == null) { - sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir => - val sep = File.separator - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } - } - } + val sep = File.separator + val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") + val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - if (propertiesFile == null) { - sys.env.get("SPARK_HOME").foreach { sparkHome => - val sep = File.separator - val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf" + confDir.foreach { sparkConfDir => + val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" val file = new File(defaultPath) if (file.exists()) { propertiesFile = file.getAbsolutePath @@ -117,19 +110,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { // Use properties file as fallback for values which have a direct analog to // arguments in this script. - master = Option(master).getOrElse(properties.get("spark.master").orNull) - executorMemory = Option(executorMemory) - .getOrElse(properties.get("spark.executor.memory").orNull) - executorCores = Option(executorCores) - .getOrElse(properties.get("spark.executor.cores").orNull) + master = Option(master).orElse(properties.get("spark.master")).orNull + executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull + executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull totalExecutorCores = Option(totalExecutorCores) - .getOrElse(properties.get("spark.cores.max").orNull) - name = Option(name).getOrElse(properties.get("spark.app.name").orNull) - jars = Option(jars).getOrElse(properties.get("spark.jars").orNull) + .orElse(properties.get("spark.cores.max")) + .orNull + name = Option(name).orElse(properties.get("spark.app.name")).orNull + jars = Option(jars).orElse(properties.get("spark.jars")).orNull // This supports env vars in older versions of Spark - master = Option(master).getOrElse(System.getenv("MASTER")) - deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE")) + master = Option(master).orElse(env.get("MASTER")).orNull + deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && primaryResource != null) { @@ -182,7 +174,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } if (master.startsWith("yarn")) { - val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR") + val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { throw new Exception(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0c324d8bdf6a4..4cba90e8f2afe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{File, OutputStream, PrintStream} +import java.io._ import scala.collection.mutable.ArrayBuffer @@ -26,6 +26,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers +import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -306,6 +307,21 @@ class SparkSubmitSuite extends FunSuite with Matchers { runSparkSubmit(args) } + test("SPARK_CONF_DIR overrides spark-defaults.conf") { + forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + unusedJar.toString) + val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + assert(appArgs.propertiesFile != null) + assert(appArgs.propertiesFile.startsWith(path)) + appArgs.executorMemory should be ("2.3g") + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String]): String = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -314,6 +330,22 @@ class SparkSubmitSuite extends FunSuite with Matchers { new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) } + + def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + val tmpDir = Files.createTempDir() + + val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") + val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) + for ((key, value) <- defaults) writer.write(s"$key $value\n") + + writer.close() + + try { + f(tmpDir.getAbsolutePath) + } finally { + Utils.deleteRecursively(tmpDir) + } + } } object JarCreationTest { diff --git a/docs/configuration.md b/docs/configuration.md index 316490f0f43fc..a782809a55ec0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1108,3 +1108,10 @@ compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface. Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a `log4j.properties` file in the `conf` directory. One way to start is to copy the existing `log4j.properties.template` located there. + +# Overriding configuration directory + +To specify a different configuration directory other than the default "SPARK_HOME/conf", +you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) +from this directory. + From 9d320e222c221e5bb827cddf01a83e64a16d74ff Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 3 Oct 2014 10:42:41 -0700 Subject: [PATCH 179/315] [SPARK-3696]Do not override the user-difined conf_dir https://issues.apache.org/jira/browse/SPARK-3696 We see if SPARK_CONF_DIR is already defined before assignment. Author: WangTaoTheTonic Closes #2541 from WangTaoTheTonic/confdir and squashes the following commits: c3f31e0 [WangTaoTheTonic] Do not override the user-difined conf_dir --- sbin/spark-config.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 2718d6cba1c9a..1d154e62ed5b6 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -33,7 +33,7 @@ this="$config_bin/$script" export SPARK_PREFIX="`dirname "$this"`"/.. export SPARK_HOME="${SPARK_PREFIX}" -export SPARK_CONF_DIR="$SPARK_HOME/conf" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" From 22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Fri, 3 Oct 2014 11:25:18 -0700 Subject: [PATCH 180/315] [SPARK-2693][SQL] Supported for UDAF Hive Aggregates like PERCENTILE Implemented UDAF Hive aggregates by adding wrapper to Spark Hive. Author: ravipesala Closes #2620 from ravipesala/SPARK-2693 and squashes the following commits: a8df326 [ravipesala] Removed resolver from constructor arguments caf25c6 [ravipesala] Fixed style issues 5786200 [ravipesala] Supported for UDAF Hive Aggregates like PERCENTILE --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 46 +++++++++++++++++-- .../sql/hive/execution/HiveUdfSuite.scala | 4 ++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 732e4976f6843..68f93f247d9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ @@ -57,7 +57,8 @@ private[hive] abstract class HiveFunctionRegistry } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdaf(functionClassName, children) - + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdtf(functionClassName, Nil, children) } else { @@ -194,6 +195,37 @@ private[hive] case class HiveGenericUdaf( def newInstance() = new HiveUdafFunction(functionClassName, children, this) } +/** It is used as a wrapper for the hive functions which uses UDAF interface */ +private[hive] case class HiveUdaf( + functionClassName: String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + type UDFType = UDAF + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction()) + + @transient + protected lazy val objectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(_.dataType).map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + + def newInstance() = + new HiveUdafFunction(functionClassName, children, this, true) +} + /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow @@ -275,14 +307,20 @@ private[hive] case class HiveGenericUdtf( private[hive] case class HiveUdafFunction( functionClassName: String, exprs: Seq[Expression], - base: AggregateExpression) + base: AggregateExpression, + isUDAFBridgeRequired: Boolean = false) extends AggregateFunction with HiveInspectors with HiveFunctionFactory { def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver]() + private val resolver = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(createFunction[UDAF]()) + } else { + createFunction[AbstractGenericUDAFResolver]() + } private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cc125d539c3c2..e4324e9528f9b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -79,6 +79,10 @@ class HiveUdfSuite extends HiveComparisonTest { sql("SELECT testUdf(pair) FROM hiveUdfTestTable") sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + + test("SPARK-2693 udaf aggregates test") { + assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From fbe8e9856b23262193105e7bf86075f516f0db25 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Oct 2014 11:36:24 -0700 Subject: [PATCH 181/315] [SPARK-2778] [yarn] Add workaround for race in MiniYARNCluster. Sometimes the cluster's start() method returns before the configuration having been updated, which is done by ClientRMService in, I assume, a separate thread (otherwise there would be no race). That can cause tests to fail if the old configuration data is read, since it will contain the wrong RM address. Author: Marcelo Vanzin Closes #2605 from vanzin/SPARK-2778 and squashes the following commits: 8d02ce0 [Marcelo Vanzin] Minor cleanup. 5bebee7 [Marcelo Vanzin] [SPARK-2778] [yarn] Add workaround for race in MiniYARNCluster. --- .../spark/deploy/yarn/YarnClusterSuite.scala | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 4b6635679f053..a826b2a78a8f5 100644 --- a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -32,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { +class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the Yarn containers, so that their output is collected // by Yarn instead of trying to overwrite unit-tests.log. @@ -66,7 +67,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) yarnCluster.init(new YarnConfiguration()) yarnCluster.start() - yarnCluster.getConfig().foreach { e => + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). + val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + config.foreach { e => sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) } @@ -86,13 +113,13 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { super.afterAll() } - ignore("run Spark in yarn-client mode") { + test("run Spark in yarn-client mode") { var result = File.createTempFile("result", null, tempDir) YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) checkResult(result) } - ignore("run Spark in yarn-cluster mode") { + test("run Spark in yarn-cluster mode") { val main = YarnClusterDriver.getClass.getName().stripSuffix("$") var result = File.createTempFile("result", null, tempDir) From bec0d0eaa33811fde72b84f7d53a6f6031e7b5d3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 3 Oct 2014 12:26:02 -0700 Subject: [PATCH 182/315] [SPARK-3007][SQL] Adds dynamic partitioning support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #2226 was reverted because it broke Jenkins builds for unknown reason. This debugging PR aims to fix the Jenkins build. This PR also fixes two bugs: 1. Compression configurations in `InsertIntoHiveTable` are disabled by mistake The `FileSinkDesc` object passed to the writer container doesn't have compression related configurations. These configurations are not taken care of until `saveAsHiveFile` is called. This PR moves compression code forward, right after instantiation of the `FileSinkDesc` object. 1. `PreInsertionCasts` doesn't take table partitions into account In `castChildOutput`, `table.attributes` only contains non-partition columns, thus for partitioned table `childOutputDataTypes` never equals to `tableOutputDataTypes`. This results funny analyzed plan like this: ``` == Analyzed Logical Plan == InsertIntoTable Map(partcol1 -> None, partcol2 -> None), false MetastoreRelation default, dynamic_part_table, None Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] ... (repeats 99 times) ... Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] Project [1 AS c_0#1164,1 AS c_1#1165,1 AS c_2#1166] Filter (key#1170 = 150) MetastoreRelation default, src, None ``` Awful though this logical plan looks, it's harmless because all projects will be eliminated by optimizer. Guess that's why this issue hasn't been caught before. Author: Cheng Lian Author: baishuo(白硕) Author: baishuo Closes #2616 from liancheng/dp-fix and squashes the following commits: 21935b6 [Cheng Lian] Adds back deleted trailing space f471c4b [Cheng Lian] PreInsertionCasts should take table partitions into account a132c80 [Cheng Lian] Fixes output compression 9c6eb2d [Cheng Lian] Adds tests to verify dynamic partitioning folder layout 0eed349 [Cheng Lian] Addresses @yhuai's comments 26632c3 [Cheng Lian] Adds more tests 9227181 [Cheng Lian] Minor refactoring c47470e [Cheng Lian] Refactors InsertIntoHiveTable to a Command 6fb16d7 [Cheng Lian] Fixes typo in test name, regenerated golden answer files d53daa5 [Cheng Lian] Refactors dynamic partitioning support b821611 [baishuo] pass check style 997c990 [baishuo] use HiveConf.DEFAULTPARTITIONNAME to replace hive.exec.default.partition.name 761ecf2 [baishuo] modify according micheal's advice 207c6ac [baishuo] modify for some bad indentation caea6fb [baishuo] modify code to pass scala style checks b660e74 [baishuo] delete a empty else branch cd822f0 [baishuo] do a little modify 8e7268c [baishuo] update file after test 3f91665 [baishuo(白硕)] Update Cast.scala 8ad173c [baishuo(白硕)] Update InsertIntoHiveTable.scala 051ba91 [baishuo(白硕)] Update Cast.scala d452eb3 [baishuo(白硕)] Update HiveQuerySuite.scala 37c603b [baishuo(白硕)] Update InsertIntoHiveTable.scala 98cfb1f [baishuo(白硕)] Update HiveCompatibilitySuite.scala 6af73f4 [baishuo(白硕)] Update InsertIntoHiveTable.scala adf02f1 [baishuo(白硕)] Update InsertIntoHiveTable.scala 1867e23 [baishuo(白硕)] Update SparkHadoopWriter.scala 6bb5880 [baishuo(白硕)] Update HiveQl.scala --- .../execution/HiveCompatibilitySuite.scala | 17 ++ .../org/apache/spark/SparkHadoopWriter.scala | 195 ---------------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 +- .../org/apache/spark/sql/hive/HiveQl.scala | 5 - .../hive/execution/InsertIntoHiveTable.scala | 218 ++++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 217 +++++++++++++++++ ...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 + ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 .../sql/hive/execution/HiveQuerySuite.scala | 100 +++++++- 15 files changed, 450 insertions(+), 306 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 556c984ad392b..35e9c9939d4b7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", + "dynamic_partition_skip_default", + "infer_bucket_sort_dyn_part", + "load_dyn_part1", + "load_dyn_part2", + "load_dyn_part3", + "load_dyn_part4", + "load_dyn_part5", + "load_dyn_part6", + "load_dyn_part7", + "load_dyn_part8", + "load_dyn_part9", + "load_dyn_part10", + "load_dyn_part11", + "load_dyn_part12", + "load_dyn_part13", + "load_dyn_part14", + "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala deleted file mode 100644 index ab7862f4f9e06..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.Writable - -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveHadoopWriter( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private var format: HiveOutputFormat[AnyRef, Writable] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val extension = Utilities.getFileExtension( - conf.value, - fileSinkConf.getCompressed, - getOutputFormat()) - - val outputName = "part-" + numfmt.format(splitID) + extension - val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) - - getOutputCommitter().setupTask(getTaskContext()) - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - path, - null) - } - - def write(value: Writable) { - if (writer != null) { - writer.write(value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? - val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - } - format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - taskContext - } - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveHadoopWriter { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 06b1446ccbd39..989a9784a438d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -144,7 +144,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val childOutputDataTypes = child.output.map(_.dataType) // Only check attributes, not partitionKeys since they are always strings. // TODO: Fully support inserting into partitioned tables. - val tableOutputDataTypes = table.attributes.map(_.dataType) + val tableOutputDataTypes = + table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6bb42eeb0550d..32c9175f181bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -806,11 +806,6 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - if (partitionKeys.values.exists(p => p.isEmpty)) { - throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + - s"dynamic partitioning.") - } - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index a284a91a91e31..16a8c782acdfa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import java.util.{HashMap => JHashMap} - import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector -import org.apache.hadoop.io.Writable +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} +import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} +import org.apache.spark.sql.hive._ +import org.apache.spark.{SerializableWritable, SparkException, TaskContext} /** * :: DeveloperApi :: @@ -51,7 +49,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode { + extends UnaryNode with Command { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -101,66 +99,61 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Writable], + rdd: RDD[Row], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) - if (isCompressed) { - // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", - // and "mapred.output.compression.type" have no impact on ORC because it uses table properties - // to store compression information. - conf.set("mapred.output.compress", "true") - fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + conf: SerializableWritable[JobConf], + writerContainer: SparkHiveWriterContainer) { + assert(valueClass != null, "Output value class not set") + conf.value.setOutputValueClass(valueClass) + + val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName + assert(outputFileFormatClassName != null, "Output format class not set") + conf.value.set("mapred.output.format.class", outputFileFormatClassName) + conf.value.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf.value, + SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + + // Note that this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) + val writer = writerContainer.getLocalFileWriter(row) + writer.write(serializer.serialize(outputData, standardOI)) } - writer.close() - writer.commit() + writerContainer.close() } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() } - override def execute() = result - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -168,50 +161,69 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + val isCompressed = sc.hiveconf.getBoolean( + ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) + if (isCompressed) { + // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", + // and "mapred.output.compression.type" have no impact on ORC because it uses table properties + // to store compression information. + sc.hiveconf.set("mapred.output.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(sc.hiveconf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(sc.hiveconf.get("mapred.output.compression.type")) + } - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + val numDynamicPartitions = partition.values.count(_.isEmpty) + val numStaticPartitions = partition.values.count(_.nonEmpty) + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" + } + + // All partition column names in the format of "//..." + val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull + + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } - serializer.serialize(outputData, standardOI) + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. + val jobConfSer = new SerializableWritable(jobConf) + + val writerContainer = if (numDynamicPartitions > 0) { + val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + } else { + new SparkHiveWriterContainer(jobConf, fileSinkConf) + } + + saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) + val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -220,10 +232,6 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -231,14 +239,26 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + if (numDynamicPartitions > 0) { + db.loadDynamicPartitions( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir + ) + } else { + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } else { db.loadTable( outputPath, @@ -251,6 +271,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala new file mode 100644 index 0000000000000..ac5c7a8220296 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import org.apache.spark.sql.Row +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. + */ +private[hive] class SparkHiveWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + protected val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private lazy val committer = conf.value.getOutputCommitter + @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient private lazy val outputFormat = + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + + def driverSideSetup() { + setIDs(0, 0, 0) + setConfParams() + committer.setupJob(jobContext) + } + + def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { + setIDs(jobId, splitId, attemptId) + setConfParams() + committer.setupTask(taskContext) + initWriters() + } + + protected def getOutputName: String = { + val numberFormat = NumberFormat.getInstance() + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) + "part-" + numberFormat.format(splitID) + extension + } + + def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + commit() + } + + def commitJob() { + committer.commitJob(jobContext) + } + + protected def initWriters() { + // NOTE this method is executed at the executor side. + // For Hive tables without partitions or with only static partitions, only 1 writer is needed. + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), + Reporter.NULL) + } + + protected def commit() { + if (committer.needsTaskCommit(taskContext)) { + try { + committer.commitTask(taskContext) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + committer.abortTask(taskContext) + throw e + } + } else { + logInfo("No need to commit output of task: " + taID.value) + } + } + + // ********* Private Functions ********* + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveWriterContainer { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } +} + +private[spark] class SparkHiveDynamicPartitionWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + dynamicPartColNames: Array[String]) + extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + + private val defaultPartName = jobConf.get( + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + + @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ + + override protected def initWriters(): Unit = { + // NOTE: This method is executed at the executor side. + // Actual writers are created for each dynamic partition on the fly. + writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + } + + override def close(): Unit = { + writers.values.foreach(_.close(false)) + commit() + } + + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + val dynamicPartPath = dynamicPartColNames + .zip(row.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = if (rawVal == null) null else String.valueOf(rawVal) + s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" + } + .mkString + + def newWriter = { + val newFileSinkDesc = new FileSinkDesc( + fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getTableInfo, + fileSinkConf.getCompressed) + newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) + newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) + + val path = { + val outputPath = FileOutputFormat.getOutputPath(conf.value) + assert(outputPath != null, "Undefined job output-path") + val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) + new Path(workPath, getOutputName) + } + + HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + newFileSinkDesc, + path, + Reporter.NULL) + } + + writers.getOrElseUpdate(dynamicPartPath, newWriter) + } +} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f5868bff22f13..2e282a9ade40c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -380,7 +383,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -568,6 +571,91 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) + createQueryTest("dynamic_partition", + """ + |DROP TABLE IF EXISTS dynamic_part_table; + |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); + | + |SET hive.exec.dynamic.partition.mode=nonstrict; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, NULL FROM src WHERE key=150; + | + |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, NULL FROM src WHERE key=150; + | + |DROP TABLE IF EXISTS dynamic_part_table; + """.stripMargin) + + test("Dynamic partition folder layout") { + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + val data = Map( + Seq("1", "1") -> 1, + Seq("1", "NULL") -> 2, + Seq("NULL", "1") -> 3, + Seq("NULL", "NULL") -> 4) + + data.foreach { case (parts, value) => + sql( + s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 + """.stripMargin) + + val partFolder = Seq("partcol1", "partcol2") + .zip(parts) + .map { case (k, v) => + if (v == "NULL") { + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + } else { + s"$k=$v" + } + } + .mkString("/") + + // Loads partition data to a temporary table to verify contents + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + } + } + + test("Partition spec validation") { + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") + + // Should throw when using strict dynamic partition mode without any static partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + // Should throw when a static partition appears after a dynamic partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + } + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -625,27 +713,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + collectResults(sql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). From 6a1d48f4f02c4498b64439c3dd5f671286a90e30 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 3 Oct 2014 12:34:27 -0700 Subject: [PATCH 183/315] [SPARK-3212][SQL] Use logical plan matching instead of temporary tables for table caching _Also addresses: SPARK-1671, SPARK-1379 and SPARK-3641_ This PR introduces a new trait, `CacheManger`, which replaces the previous temporary table based caching system. Instead of creating a temporary table that shadows an existing table with and equivalent cached representation, the cached manager maintains a separate list of logical plans and their cached data. After optimization, this list is searched for any matching plan fragments. When a matching plan fragment is found it is replaced with the cached data. There are several advantages to this approach: - Calling .cache() on a SchemaRDD now works as you would expect, and uses the more efficient columnar representation. - Its now possible to provide a list of temporary tables, without having to decide if a given table is actually just a cached persistent table. (To be done in a follow-up PR) - In some cases it is possible that cached data will be used, even if a cached table was not explicitly requested. This is because we now look at the logical structure instead of the table name. - We now correctly invalidate when data is inserted into a hive table. Author: Michael Armbrust Closes #2501 from marmbrus/caching and squashes the following commits: 63fbc2c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching. 0ea889e [Michael Armbrust] Address comments. 1e23287 [Michael Armbrust] Add support for cache invalidation for hive inserts. 65ed04a [Michael Armbrust] fix tests. bdf9a3f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching b4b77f2 [Michael Armbrust] Address comments 6923c9d [Michael Armbrust] More comments / tests 80f26ac [Michael Armbrust] First draft of improved semantics for Spark SQL caching. --- .../sql/catalyst/analysis/Analyzer.scala | 3 + .../expressions/namedExpressions.scala | 4 +- .../catalyst/plans/logical/LogicalPlan.scala | 42 ++++++ .../catalyst/plans/logical/TestRelation.scala | 6 + .../plans/logical/basicOperators.scala | 4 +- .../sql/catalyst/plans/SameResultSuite.scala | 62 ++++++++ .../org/apache/spark/sql/CacheManager.scala | 139 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 51 +------ .../org/apache/spark/sql/SchemaRDD.scala | 23 ++- .../org/apache/spark/sql/SchemaRDDLike.scala | 5 +- .../spark/sql/api/java/JavaSQLContext.scala | 10 +- .../columnar/InMemoryColumnarTableScan.scala | 28 +++- .../spark/sql/execution/ExistingRDD.scala | 119 +++++++++++++++ .../spark/sql/execution/SparkPlan.scala | 33 ----- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../spark/sql/execution/basicOperators.scala | 39 ----- .../apache/spark/sql/CachedTableSuite.scala | 103 +++++++------ .../columnar/InMemoryColumnarQuerySuite.scala | 7 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- .../org/apache/spark/sql/hive/TestHive.scala | 5 +- .../hive/execution/InsertIntoHiveTable.scala | 3 + .../spark/sql/hive/CachedTableSuite.scala | 100 ++++++++----- 23 files changed, 567 insertions(+), 241 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 71810b798bd04..fe83eb12502dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -93,6 +93,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object ResolveRelations extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) => + i.copy( + table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias))) case UnresolvedRelation(databaseName, name, alias) => catalog.lookupRelation(databaseName, name, alias) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 59fb0311a9c44..e5a958d599393 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def withName(newName: String): Attribute def toAttribute = this - def newInstance: Attribute + def newInstance(): Attribute } @@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea h } - override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 28d863e58beca..4f8ad8a7e0223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees @@ -72,6 +73,47 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = !children.exists(!_.resolved) + /** + * Returns true when the given logical plan will return the same results as this logical plan. + * + * Since its likely undecideable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. Logical operators that + * can do better should override this function. + */ + def sameResult(plan: LogicalPlan): Boolean = { + plan.getClass == this.getClass && + plan.children.size == children.size && { + logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") + cleanArgs == plan.cleanArgs + } && + (plan.children, children).zipped.forall(_ sameResult _) + } + + /** Args that have cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + val input = children.flatMap(_.output) + productIterator.map { + // Children are checked using sameResult above. + case tn: TreeNode[_] if children contains tn => null + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case s: Option[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case s: Seq[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case other => other + }.toSeq + } + /** * Optionally resolves the given string to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index f8fe558511bfd..19769986ef58c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) } override protected def stringArgs = Iterator(output) + + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LocalRelation(otherOutput, otherData) => + otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 391508279bb80..f8e9930ac270d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -105,8 +105,8 @@ case class InsertIntoTable( child: LogicalPlan, overwrite: Boolean) extends LogicalPlan { - // The table being inserted into is a child for the purposes of transformations. - override def children = table :: child :: Nil + + override def children = child :: Nil override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala new file mode 100644 index 0000000000000..e8a793d107451 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ + +/** + * Provides helper methods for comparing plans. + */ +class SameResultSuite extends FunSuite { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + val aAnalyzed = a.analyze + val bAnalyzed = b.analyze + + if (aAnalyzed.sameResult(bAnalyzed) != result) { + val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n") + fail(s"Plans should return sameResult = $result\n$comparison") + } + } + + test("relations") { + assertSameResult(testRelation, testRelation2) + } + + test("projections") { + assertSameResult(testRelation.select('a), testRelation2.select('a)) + assertSameResult(testRelation.select('b), testRelation2.select('b)) + assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) + + assertSameResult(testRelation, testRelation2.select('a), false) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false) + } + + test("filters") { + assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala new file mode 100644 index 0000000000000..aebdbb68e49b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.locks.ReentrantReadWriteLock + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY + +/** Holds a cached logical plan and its data */ +private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) + +/** + * Provides support in a SQLContext for caching query results and automatically using these cached + * results when subsequent queries are executed. Data is cached using byte buffers stored in an + * InMemoryRelation. This relation is automatically substituted query plans that return the + * `sameResult` as the originally cached query. + */ +private[sql] trait CacheManager { + self: SQLContext => + + @transient + private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + + @transient + private val cacheLock = new ReentrantReadWriteLock + + /** Returns true if the table is currently cached in-memory. */ + def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = cacheQuery(table(tableName)) + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val lock = cacheLock.readLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val lock = cacheLock.writeLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + private[sql] def clearCache(): Unit = writeLock { + cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.clear() + } + + /** Caches the data produced by the logical representation of the given schema rdd. */ + private[sql] def cacheQuery( + query: SchemaRDD, + storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + if (lookupCachedData(planToCache).nonEmpty) { + logWarning("Asked to cache already cached data.") + } else { + cachedData += + CachedData( + planToCache, + InMemoryRelation( + useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan)) + } + } + + /** Removes the data for the given SchemaRDD from the cache */ + private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache)) + + if (dataIndex < 0) { + throw new IllegalArgumentException(s"Table $query is not cached.") + } + + cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData.remove(dataIndex) + } + + + /** Optionally returns cached data for the given SchemaRDD */ + private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + lookupCachedData(query.queryExecution.optimizedPlan) + } + + /** Optionally returns cached data for the given LogicalPlan. */ + private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { + cachedData.find(_.plan.sameResult(plan)) + } + + /** Replaces segments of the given logical plan with cached versions where possible. */ + private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + plan transformDown { + case currentFragment => + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) + } + } + + /** + * Invalidates the cache of any data that contains `plan`. Note that it is possible that this + * function will over invalidate. + */ + private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock { + cachedData.foreach { + case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => + data.cachedRepresentation.recache() + case _ => + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a42bedbe6c04e..7a55c5bf97a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext} class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with SQLConf + with CacheManager with ExpressionConversions with UDFRegistration with Serializable { @@ -96,7 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + new SchemaRDD(this, + LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } /** @@ -133,7 +135,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) new SchemaRDD(this, logicalPlan) } @@ -272,45 +274,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) - /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = { - val currentTable = table(tableName).queryExecution.analyzed - val asInMemoryRelation = currentTable match { - case _: InMemoryRelation => - currentTable - - case _ => - InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) - } - - catalog.registerTable(None, tableName, asInMemoryRelation) - } - - /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = { - table(tableName).queryExecution.analyzed match { - // This is kind of a hack to make sure that if this was just an RDD registered as a table, - // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) - case inMem: InMemoryRelation => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") - } - } - - /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = { - val relation = table(tableName).queryExecution.analyzed - relation match { - case _: InMemoryRelation => true - case _ => false - } - } - protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -401,10 +364,12 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) + lazy val withCachedData = useCachedData(optimizedPlan) + // TODO: Don't just pick the first one... lazy val sparkPlan = { SparkPlan.currentContext.set(self) - planner(optimizedPlan).next() + planner(withCachedData).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. @@ -526,6 +491,6 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3b873f7c62cb6..594bf8ffc20e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.storage.StorageLevel + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.api.java.JavaRDD /** @@ -442,8 +444,7 @@ class SchemaRDD( */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { new SchemaRDD(sqlContext, - SparkLogicalPlan( - ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext)) + LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) } // ======================================================================= @@ -497,4 +498,20 @@ class SchemaRDD( override def subtract(other: RDD[Row], p: Partitioner) (implicit ord: Ordering[Row] = null): SchemaRDD = applySchema(super.subtract(other, p)(ord)) + + /** Overridden cache function will always use the in-memory columnar caching. */ + override def cache(): this.type = { + sqlContext.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheQuery(this, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.uncacheQuery(this, blocking) + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index e52eeb3e1c47e..25ba7d88ba538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.execution.LogicalRDD /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) @@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => - queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => baseLogicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 150ff8a42063d..c006c4330ff66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) } /** @@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val scalaRowRDD = rowRDD.rdd.map(r => r.row) val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = - SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -151,7 +151,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -167,7 +167,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 8a3612cdf19be..cec82a7f2df94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -27,10 +27,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.storage.StorageLevel private[sql] object InMemoryRelation { - def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, child)() + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) @@ -39,6 +44,7 @@ private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, + storageLevel: StorageLevel, child: SparkPlan) (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { @@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation( // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { + buildBuffers() + } + + def recache() = { + _cachedColumnBuffers.unpersist() + _cachedColumnBuffers = null + buildBuffers() + } + + private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { @@ -80,12 +96,17 @@ private[sql] case class InMemoryRelation( def hasNext = rowIterator.hasNext } - }.cache() + }.persist(storageLevel) cached.setName(child.toString) _cachedColumnBuffers = cached } + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { + InMemoryRelation( + newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers) + } + override def children = Seq.empty override def newInstance() = { @@ -93,6 +114,7 @@ private[sql] case class InMemoryRelation( output.map(_.newInstance), useCompression, batchSize, + storageLevel, child)( _cachedColumnBuffers).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala new file mode 100644 index 0000000000000..2ddf513b6fc98 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object RDDConversions { + def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + i += 1 + } + + mutableRow + } + } + } + } + + /* + def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { + LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) + } + */ +} + +case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { + + def children = Nil + + def newInstance() = + LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] + + override def sameResult(plan: LogicalPlan) = plan match { + case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} + +case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { + + def output = alreadyPlanned.output + override def children = Nil + + override final def newInstance(): this.type = { + SparkLogicalPlan( + alreadyPlanned match { + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case _ => sys.error("Multiple instance of the same relation detected.") + })(sqlContext).asInstanceOf[this.type] + } + + override def sameResult(plan: LogicalPlan) = plan match { + case SparkLogicalPlan(ExistingRdd(_, rdd)) => + rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2b8913985b028..b1a7948b66cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -/** - * :: DeveloperApi :: - * Allows already planned SparkQueries to be linked into logical query plans. - * - * Note that in general it is not valid to use this class to link multiple copies of the same - * physical operator into the same query plan as this violates the uniqueness of expression ids. - * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just - * replace the output attributes with new copies of themselves without breaking any attribute - * linking. - */ -@DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - def output = alreadyPlanned.output - override def children = Nil - - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } - - @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. - sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) - ) - -} - private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { self: Product => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 45687d960404c..cf93d5ad7b503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -272,10 +272,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - ExistingRdd( + PhysicalRDD( output, - ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -287,12 +288,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil case logical.NoRelation => - execution.ExistingRdd(Nil, singleRowRdd) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index cac376608be29..977f3c9f32096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -210,45 +210,6 @@ case class Sort( override def output = child.output } -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object ExistingRdd { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { - data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) - i += 1 - } - - mutableRow - } - } - } - } - - def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = { - ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 591592841e9fe..957388e99bd85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,13 +20,30 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ case class BigData(s: String) class CachedTableSuite extends QueryTest { + import TestSQLContext._ TestData // Load test tables. + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } + test("too big for memory") { val data = "*" * 10000 sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") @@ -35,19 +52,21 @@ class CachedTableSuite extends QueryTest { uncacheTable("bigData") } + test("calling .cache() should use inmemory columnar caching") { + table("testData").cache() + + assertCached(table("testData")) + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) cacheTable("testData") - table("testData").queryExecution.analyzed match { - case _: InMemoryRelation => - case _ => - fail("testData should be cached") - } + assertCached(table("testData")) cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => @@ -55,81 +74,69 @@ class CachedTableSuite extends QueryTest { } test("read from cached table and uncache") { - TestSQLContext.cacheTable("testData") + cacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } + assertCached(table("testData")) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } + assertCached(table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") - TestSQLContext.cacheTable("selectStar") - TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() - TestSQLContext.uncacheTable("selectStar") + sql("SELECT * FROM testData").registerTempTable("selectStar") + cacheTable("selectStar") + sql("SELECT * FROM selectStar WHERE key = 1").collect() + uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - TestSQLContext.cacheTable("testData") + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() + cacheTable("testData") checkAnswer( - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { - TestSQLContext.sql("CACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'testData' should be cached") - } - assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached") + sql("CACHE TABLE testData") + assertCached(table("testData")) - TestSQLContext.sql("UNCACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached") - case _ => // Found evidence of uncaching - } - assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") + assert(isCached("testData"), "Table 'testData' should be cached") + + sql("UNCACHE TABLE testData") + assertCached(table("testData"), 0) + assert(!isCached("testData"), "Table 'testData' should not be cached") } test("CACHE TABLE tableName AS SELECT Star Table") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } test("'CACHE TABLE tableName AS SELECT ..'") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index c1278248ef655..9775dd26b7773 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, TestData} +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) } @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 989a9784a438d..cc0605b0adb35 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -133,11 +133,6 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) - - case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), _, child, _) => - castChildOutput(p, table, child) } def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { @@ -306,7 +301,7 @@ private[hive] case class MetastoreRelation HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true - )(qualifiers = tableName +: alias.toSeq) + )(qualifiers = Seq(alias.getOrElse(tableName))) } // Must be a stable value since new attributes are born here. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8ac17f37201a8..508d8239c7628 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType -import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ @@ -161,10 +160,7 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - case logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), partition, child, overwrite) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.CreateTableAsSelect(database, tableName, child) => val query = planLater(child) CreateTableAsSelect( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 4a999b98ad92b..c0e69393cc2e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -353,7 +353,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { - // Marks the table as loaded first to prevent infite mutually recursive table loading. + // Marks the table as loaded first to prevent infinite mutually recursive table loading. loadedTables += name logInfo(s"Loading test table $name") val createCmds = @@ -383,6 +383,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } + clearCache() loadedTables.clear() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") @@ -428,7 +429,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logError(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError("FATAL ERROR: Failed to reset TestDB state.", e) // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 16a8c782acdfa..f8b4e898ec41d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -267,6 +267,9 @@ case class InsertIntoHiveTable( holdDDLTime) } + // Invalidate the cache. + sqlContext.invalidateCache(table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index b3057cd618c66..158cfb5bbee7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,22 +17,60 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.{QueryTest, SchemaRDD} import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} -import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive -class CachedTableSuite extends HiveComparisonTest { +class CachedTableSuite extends QueryTest { import TestHive._ - TestHive.loadTestTable("src") + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } test("cache table") { - TestHive.cacheTable("src") + val preCacheResults = sql("SELECT * FROM src").collect().toSeq + + cacheTable("src") + assertCached(sql("SELECT * FROM src")) + + checkAnswer( + sql("SELECT * FROM src"), + preCacheResults) + + uncacheTable("src") + assertCached(sql("SELECT * FROM src"), 0) } - createQueryTest("read from cached table", - "SELECT * FROM src LIMIT 1", reset = false) + test("cache invalidation") { + sql("CREATE TABLE cachedTable(key INT, value STRING)") + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + + cacheTable("cachedTable") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer( + sql("SELECT * FROM cachedTable"), + table("src").collect().toSeq ++ table("src").collect().toSeq) + + sql("DROP TABLE cachedTable") + } test("Drop cached table") { sql("CREATE TABLE test(a INT)") @@ -48,25 +86,6 @@ class CachedTableSuite extends HiveComparisonTest { sql("DROP TABLE IF EXISTS nonexistantTable") } - test("check that table is cached and uncache") { - TestHive.table("src").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } - TestHive.uncacheTable("src") - } - - createQueryTest("read from uncached table", - "SELECT * FROM src LIMIT 1", reset = false) - - test("make sure table is uncached") { - TestHive.table("src").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } - } - test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { TestHive.uncacheTable("src") @@ -75,23 +94,24 @@ class CachedTableSuite extends HiveComparisonTest { test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { TestHive.sql("CACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'src' should be cached") - } + assertCached(table("src")) assert(TestHive.isCached("src"), "Table 'src' should be cached") TestHive.sql("UNCACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") - case _ => // Found evidence of uncaching - } + assertCached(table("src"), 0) assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } - - test("'CACHE TABLE tableName AS SELECT ..'") { - TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestHive.uncacheTable("testCacheTable") - } + + test("CACHE TABLE AS SELECT") { + assertCached(sql("SELECT * FROM src"), 0) + sql("CACHE TABLE test AS SELECT key FROM src") + + checkAnswer( + sql("SELECT * FROM test"), + sql("SELECT key FROM src").collect().toSeq) + + assertCached(sql("SELECT * FROM test")) + + assertCached(sql("SELECT * FROM test JOIN test"), 2) + } } From a8c52d5343e19731909e73db5de151a324d31cd5 Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Fri, 3 Oct 2014 12:58:04 -0700 Subject: [PATCH 184/315] [SPARK-3535][Mesos] Fix resource handling. Author: Brenden Matthews Closes #2401 from brndnmtthws/master and squashes the following commits: 4abaa5d [Brenden Matthews] [SPARK-3535][Mesos] Fix resource handling. --- .../mesos/CoarseMesosSchedulerBackend.scala | 7 ++-- .../scheduler/cluster/mesos/MemoryUtils.scala | 35 +++++++++++++++++++ .../cluster/mesos/MesosSchedulerBackend.scala | 34 ++++++++++++++---- docs/configuration.md | 11 ++++++ 4 files changed, 79 insertions(+), 8 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 64568409dbafd..3161f1ee9fa8a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -198,7 +198,9 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 && + if (totalCoresAcquired < maxCores && + mem >= MemoryUtils.calculateTotalMemory(sc) && + cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { // Launch an executor on the slave @@ -214,7 +216,8 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", sc.executorMemory)) + .addResources(createResource("mem", + MemoryUtils.calculateTotalMemory(sc))) .build() d.launchTasks( Collections.singleton(offer.getId), Collections.singletonList(task), filters) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala new file mode 100644 index 0000000000000..5101ec8352e79 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.spark.SparkContext + +private[spark] object MemoryUtils { + // These defaults copied from YARN + val OVERHEAD_FRACTION = 1.07 + val OVERHEAD_MINIMUM = 384 + + def calculateTotalMemory(sc: SparkContext) = { + math.max( + sc.conf.getOption("spark.mesos.executor.memoryOverhead") + .getOrElse(OVERHEAD_MINIMUM.toString) + .toInt + sc.executorMemory, + OVERHEAD_FRACTION * sc.executorMemory + ) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a9ef126f5de0e..4c49aa074ebc0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -124,15 +124,24 @@ private[spark] class MesosSchedulerBackend( command.setValue("cd %s*; ./sbin/spark-executor".format(basename)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } + val cpus = Resource.newBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder() + .setValue(scheduler.CPUS_PER_TASK).build()) + .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build()) + .setScalar( + Value.Scalar.newBuilder() + .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() ExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) + .addResources(cpus) .addResources(memory) .build() } @@ -204,18 +213,31 @@ private[spark] class MesosSchedulerBackend( val offerableWorkers = new ArrayBuffer[WorkerOffer] val offerableIndices = new HashMap[String, Int] - def enoughMemory(o: Offer) = { + def sufficientOffer(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId) + (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) } - for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices.put(offer.getSlaveId.getValue, index) + for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { + val slaveId = offer.getSlaveId.getValue + offerableIndices.put(slaveId, index) + val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { + getResource(offer.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + getResource(offer.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK + } offerableWorkers += new WorkerOffer( offer.getSlaveId.getValue, offer.getHostname, - getResource(offer.getResourcesList, "cpus").toInt) + cpus) } // Call into the TaskSchedulerImpl diff --git a/docs/configuration.md b/docs/configuration.md index a782809a55ec0..1c33855365170 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -253,6 +253,17 @@ Apart from these, the following properties are also available, and may be useful spark.executor.uri. + + + + +
    StructType org.apache.spark.sql.api.java + DataType.createStructType(fields)
    Note: fields is a List or an array of StructFields. Also, two fields with the same name are not allowed. @@ -1394,7 +1416,7 @@ please use factory methods provided in
    All data types of Spark SQL are located in the package of `pyspark.sql`. -You can access them by doing +You can access them by doing {% highlight python %} from pyspark.sql import * {% endhighlight %} @@ -1518,7 +1540,7 @@ from pyspark.sql import *
    StructType list or tuple + StructType(fields)
    Note: fields is a Seq of StructFields. Also, two fields with the same name are not allowed. From a9e910430fb6bb4ef1f6ae20761c43b96bb018df Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 16 Sep 2014 12:41:45 -0700 Subject: [PATCH 010/315] [SPARK-3546] InputStream of ManagedBuffer is not closed and causes running out of file descriptor Author: Kousuke Saruta Closes #2408 from sarutak/resolve-resource-leak-issue and squashes the following commits: 074781d [Kousuke Saruta] Modified SuffleBlockFetcherIterator 5f63f67 [Kousuke Saruta] Move metrics increment logic and debug logging outside try block b37231a [Kousuke Saruta] Modified FileSegmentManagedBuffer#nioByteBuffer to check null or not before invoking channel.close bf29d4a [Kousuke Saruta] Modified FileSegment to close channel --- .../org/apache/spark/network/ManagedBuffer.scala | 12 ++++++++++-- .../spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index dcecb6beeea9b..e990c1da6730f 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -19,6 +19,7 @@ package org.apache.spark.network import java.io.{FileInputStream, RandomAccessFile, File, InputStream} import java.nio.ByteBuffer +import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode import com.google.common.io.ByteStreams @@ -66,8 +67,15 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt override def size: Long = length override def nioByteBuffer(): ByteBuffer = { - val channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + var channel: FileChannel = null + try { + channel = new RandomAccessFile(file, "r").getChannel + channel.map(MapMode.READ_ONLY, offset, length) + } finally { + if (channel != null) { + channel.close() + } + } } override def inputStream(): InputStream = { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c8e708aa6b1bc..d868758a7f549 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.{TaskContext, Logging} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils From ec1adecbb72d291d7ef122fb0505bae53116e0e6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Sep 2014 12:51:58 -0700 Subject: [PATCH 011/315] [SPARK-3430] [PySpark] [Doc] generate PySpark API docs using Sphinx Using Sphinx to generate API docs for PySpark. requirement: Sphinx ``` $ cd python/docs/ $ make html ``` The generated API docs will be located at python/docs/_build/html/index.html It can co-exists with those generated by Epydoc. This is the first working version, after merging in, then we can continue to improve it and replace the epydoc finally. Author: Davies Liu Closes #2292 from davies/sphinx and squashes the following commits: 425a3b1 [Davies Liu] cleanup 1573298 [Davies Liu] move docs to python/docs/ 5fe3903 [Davies Liu] Merge branch 'master' into sphinx 9468ab0 [Davies Liu] fix makefile b408f38 [Davies Liu] address all comments e2ccb1b [Davies Liu] update name and version 9081ead [Davies Liu] generate PySpark API docs using Sphinx --- python/docs/Makefile | 179 ++++++++++++++++++ python/docs/conf.py | 332 ++++++++++++++++++++++++++++++++++ python/docs/epytext.py | 27 +++ python/docs/index.rst | 37 ++++ python/docs/make.bat | 242 +++++++++++++++++++++++++ python/docs/modules.rst | 7 + python/docs/pyspark.mllib.rst | 77 ++++++++ python/docs/pyspark.rst | 18 ++ python/docs/pyspark.sql.rst | 10 + python/pyspark/broadcast.py | 3 + python/pyspark/context.py | 2 +- python/pyspark/serializers.py | 3 + python/pyspark/sql.py | 12 +- 13 files changed, 944 insertions(+), 5 deletions(-) create mode 100644 python/docs/Makefile create mode 100644 python/docs/conf.py create mode 100644 python/docs/epytext.py create mode 100644 python/docs/index.rst create mode 100644 python/docs/make.bat create mode 100644 python/docs/modules.rst create mode 100644 python/docs/pyspark.mllib.rst create mode 100644 python/docs/pyspark.rst create mode 100644 python/docs/pyspark.sql.rst diff --git a/python/docs/Makefile b/python/docs/Makefile new file mode 100644 index 0000000000000..8a1324eecd325 --- /dev/null +++ b/python/docs/Makefile @@ -0,0 +1,179 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.8.2.1-src.zip) + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pyspark.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pyspark.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/pyspark" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pyspark" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/python/docs/conf.py b/python/docs/conf.py new file mode 100644 index 0000000000000..c368cf81a003b --- /dev/null +++ b/python/docs/conf.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +# +# pyspark documentation build configuration file, created by +# sphinx-quickstart on Thu Aug 28 15:17:47 2014. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys +import os + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('.')) + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'epytext', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'PySpark' +copyright = u'2014, Author' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '1.1' +# The full version, including alpha/beta/rc tags. +release = '' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['_build'] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +#keep_warnings = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. These files are copied +# directly to the root of the documentation. +#html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'pysparkdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + ('index', 'pyspark.tex', u'pyspark Documentation', + u'Author', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'pyspark', u'pyspark Documentation', + [u'Author'], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'pyspark', u'pyspark Documentation', + u'Author', 'pyspark', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +#texinfo_no_detailmenu = False + + +# -- Options for Epub output ---------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = u'pyspark' +epub_author = u'Author' +epub_publisher = u'Author' +epub_copyright = u'2014, Author' + +# The basename for the epub file. It defaults to the project name. +#epub_basename = u'pyspark' + +# The HTML theme for the epub output. Since the default themes are not optimized +# for small screen space, using the same theme for HTML and epub output is +# usually not wise. This defaults to 'epub', a theme designed to save visual +# space. +#epub_theme = 'epub' + +# The language of the text. It defaults to the language option +# or en if the language is not set. +#epub_language = '' + +# The scheme of the identifier. Typical schemes are ISBN or URL. +#epub_scheme = '' + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +#epub_identifier = '' + +# A unique identification for the text. +#epub_uid = '' + +# A tuple containing the cover image and cover page html template filenames. +#epub_cover = () + +# A sequence of (type, uri, title) tuples for the guide element of content.opf. +#epub_guide = () + +# HTML files that should be inserted before the pages created by sphinx. +# The format is a list of tuples containing the path and title. +#epub_pre_files = [] + +# HTML files shat should be inserted after the pages created by sphinx. +# The format is a list of tuples containing the path and title. +#epub_post_files = [] + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + +# The depth of the table of contents in toc.ncx. +#epub_tocdepth = 3 + +# Allow duplicate toc entries. +#epub_tocdup = True + +# Choose between 'default' and 'includehidden'. +#epub_tocscope = 'default' + +# Fix unsupported image types using the PIL. +#epub_fix_images = False + +# Scale large images. +#epub_max_image_width = 0 + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#epub_show_urls = 'inline' + +# If false, no index is generated. +#epub_use_index = True diff --git a/python/docs/epytext.py b/python/docs/epytext.py new file mode 100644 index 0000000000000..61d731bff570d --- /dev/null +++ b/python/docs/epytext.py @@ -0,0 +1,27 @@ +import re + +RULES = ( + (r"<[\w.]+>", r""), + (r"L{([\w.()]+)}", r":class:`\1`"), + (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), + (r"C{([\w.()]+)}", r":class:`\1`"), + (r"[IBCM]{(.+)}", r"`\1`"), + ('pyspark.rdd.RDD', 'RDD'), +) + +def _convert_epytext(line): + """ + >>> _convert_epytext("L{A}") + :class:`A` + """ + line = line.replace('@', ':') + for p, sub in RULES: + line = re.sub(p, sub, line) + return line + +def _process_docstring(app, what, name, obj, options, lines): + for i in range(len(lines)): + lines[i] = _convert_epytext(lines[i]) + +def setup(app): + app.connect("autodoc-process-docstring", _process_docstring) diff --git a/python/docs/index.rst b/python/docs/index.rst new file mode 100644 index 0000000000000..25b3f9bd93e63 --- /dev/null +++ b/python/docs/index.rst @@ -0,0 +1,37 @@ +.. pyspark documentation master file, created by + sphinx-quickstart on Thu Aug 28 15:17:47 2014. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to PySpark API reference! +=================================== + +Contents: + +.. toctree:: + :maxdepth: 2 + + pyspark + pyspark.sql + pyspark.mllib + + +Core classes: +--------------- + + :class:`pyspark.SparkContext` + + Main entry point for Spark functionality. + + :class:`pyspark.RDD` + + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + diff --git a/python/docs/make.bat b/python/docs/make.bat new file mode 100644 index 0000000000000..adad44fd7536a --- /dev/null +++ b/python/docs/make.bat @@ -0,0 +1,242 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/docs/modules.rst b/python/docs/modules.rst new file mode 100644 index 0000000000000..183564659fbcf --- /dev/null +++ b/python/docs/modules.rst @@ -0,0 +1,7 @@ +. += + +.. toctree:: + :maxdepth: 4 + + pyspark diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst new file mode 100644 index 0000000000000..e95d19e97f151 --- /dev/null +++ b/python/docs/pyspark.mllib.rst @@ -0,0 +1,77 @@ +pyspark.mllib package +===================== + +Submodules +---------- + +pyspark.mllib.classification module +----------------------------------- + +.. automodule:: pyspark.mllib.classification + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.clustering module +------------------------------- + +.. automodule:: pyspark.mllib.clustering + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.linalg module +--------------------------- + +.. automodule:: pyspark.mllib.linalg + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.random module +--------------------------- + +.. automodule:: pyspark.mllib.random + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.recommendation module +----------------------------------- + +.. automodule:: pyspark.mllib.recommendation + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.regression module +------------------------------- + +.. automodule:: pyspark.mllib.regression + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.stat module +------------------------- + +.. automodule:: pyspark.mllib.stat + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.tree module +------------------------- + +.. automodule:: pyspark.mllib.tree + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.util module +------------------------- + +.. automodule:: pyspark.mllib.util + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst new file mode 100644 index 0000000000000..a68bd62433085 --- /dev/null +++ b/python/docs/pyspark.rst @@ -0,0 +1,18 @@ +pyspark package +=============== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 1 + + pyspark.mllib + pyspark.sql + +Contents +-------- + +.. automodule:: pyspark + :members: + :undoc-members: diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst new file mode 100644 index 0000000000000..65b3650ae10ab --- /dev/null +++ b/python/docs/pyspark.sql.rst @@ -0,0 +1,10 @@ +pyspark.sql module +================== + +Module contents +--------------- + +.. automodule:: pyspark.sql + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 5c7c9cc161dff..f124dc6c07575 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -78,6 +78,9 @@ def value(self): return self._value def unpersist(self, blocking=False): + """ + Delete cached copies of this broadcast on the executors. + """ self._jbroadcast.unpersist(blocking) os.unlink(self.path) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a33aae87f65e8..a17f2c1203d36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -53,7 +53,7 @@ class SparkContext(object): """ Main entry point for Spark functionality. A SparkContext represents the - connection to a Spark cluster, and can be used to create L{RDD}s and + connection to a Spark cluster, and can be used to create L{RDD} and broadcast variables on that cluster. """ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ec3c6f055441d..44ac5642836e0 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -110,6 +110,9 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __repr__(self): + return "<%s object>" % self.__class__.__name__ + class FramedSerializer(Serializer): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 621a556ec6356..8f6dbab240c7b 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -289,7 +289,7 @@ class StructType(DataType): """Spark SQL StructType The data type representing rows. - A StructType object comprises a list of L{StructField}s. + A StructType object comprises a list of L{StructField}. """ @@ -904,7 +904,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as + A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -994,7 +994,7 @@ def registerFunction(self, name, f, returnType=StringType()): str(returnType)) def inferSchema(self, rdd): - """Infer and apply a schema to an RDD of L{Row}s. + """Infer and apply a schema to an RDD of L{Row}. We peek at the first row of the RDD to determine the fields' names and types. Nested collections are supported, which include array, @@ -1047,7 +1047,7 @@ def inferSchema(self, rdd): def applySchema(self, rdd, schema): """ - Applies the given schema to the given RDD of L{tuple} or L{list}s. + Applies the given schema to the given RDD of L{tuple} or L{list}. These tuples or lists can contain complex nested structures like lists, maps or nested rows. @@ -1183,6 +1183,7 @@ def jsonFile(self, path, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( @@ -1193,6 +1194,7 @@ def jsonFile(self, path, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -1233,6 +1235,7 @@ def jsonRDD(self, rdd, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( @@ -1243,6 +1246,7 @@ def jsonRDD(self, rdd, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", From b20171267d610715d5b0a86b474c903e9bc3a1a3 Mon Sep 17 00:00:00 2001 From: Dan Osipov Date: Tue, 16 Sep 2014 13:40:16 -0700 Subject: [PATCH 012/315] [SPARK-787] Add S3 configuration parameters to the EC2 deploy scripts When deploying to AWS, there is additional configuration that is required to read S3 files. EMR creates it automatically, there is no reason that the Spark EC2 script shouldn't. This PR requires a corresponding PR to the mesos/spark-ec2 to be merged, as it gets cloned in the process of setting up machines: https://github.com/mesos/spark-ec2/pull/58 Author: Dan Osipov Closes #1120 from danosipov/s3_credentials and squashes the following commits: 758da8b [Dan Osipov] Modify documentation to include the new parameter 71fab14 [Dan Osipov] Use a parameter --copy-aws-credentials to enable S3 credential deployment 7e0da26 [Dan Osipov] Get AWS credentials out of boto connection instance 39bdf30 [Dan Osipov] Add S3 configuration parameters to the EC2 deploy scripts --- docs/ec2-scripts.md | 2 +- ec2/deploy.generic/root/spark-ec2/ec2-variables.sh | 2 ++ ec2/spark_ec2.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index f5ac6d894e1eb..b2ca6a9b48f32 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -156,6 +156,6 @@ If you have a patch or suggestion for one of these limitations, feel free to # Accessing Data in S3 -Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. You will also need to set your Amazon security credentials, either by setting the environment variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` before your program or through `SparkContext.hadoopConfiguration`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). +Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory. diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh index 3570891be804e..740c267fd9866 100644 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -30,3 +30,5 @@ export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" export SWAP_MB="{{swap}}" export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" export SPARK_MASTER_OPTS="{{spark_master_opts}}" +export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" +export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" \ No newline at end of file diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5682e96aa8770..abac71eaca595 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -158,6 +158,9 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--copy-aws-credentials", action="store_true", default=False, + help="Add AWS credentials to hadoop configuration to allow Spark to access S3") (opts, args) = parser.parse_args() if len(args) != 2: @@ -714,6 +717,13 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): "spark_master_opts": opts.master_opts } + if opts.copy_aws_credentials: + template_vars["aws_access_key_id"] = conn.aws_access_key_id + template_vars["aws_secret_access_key"] = conn.aws_secret_access_key + else: + template_vars["aws_access_key_id"] = "" + template_vars["aws_secret_access_key"] = "" + # Create a temp directory in which we will place all the files to be # deployed after we substitue template parameters in them tmp_dir = tempfile.mkdtemp() From a6e1712f1e9c36deb24c5073aa8edcfc047d76eb Mon Sep 17 00:00:00 2001 From: Evan Chan Date: Tue, 16 Sep 2014 13:46:06 -0700 Subject: [PATCH 013/315] Add a Community Projects page This adds a new page to the docs listing community projects -- those created outside of Apache Spark that are of interest to the community of Spark users. Anybody can add to it just by submitting a PR. There was a discussion thread about alternatives: * Creating a Github organization for Spark projects - we could not find any sponsors for this, and it would be difficult to organize since many folks just create repos in their company organization or personal accounts * Apache has some place for storing community projects, but it was deemed difficult to work with, and again would be some permissions issues -- not everyone could update it. Author: Evan Chan Closes #2219 from velvia/community-projects-page and squashes the following commits: 7316822 [Evan Chan] Point to Spark wiki: supplemental projects page 613b021 [Evan Chan] Add a few more projects a85eaaf [Evan Chan] Add a Community Projects page --- docs/_layouts/global.html | 3 ++- docs/index.md | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index a53e8a775b71f..627ed37de4a9c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -111,6 +111,7 @@
  • Building Spark
  • Contributing to Spark
  • +
  • Supplemental Projects
  • @@ -151,7 +152,7 @@

    {{ page.title }}

    MathJax.Hub.Config({ tex2jax: { inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ], - displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], + displayMath: [ ["$$","$$"], ["\\[", "\\]"] ], processEscapes: true, skipTags: ['script', 'noscript', 'style', 'textarea', 'pre'] } diff --git a/docs/index.md b/docs/index.md index e8ebadbd4e427..edd622ec90f64 100644 --- a/docs/index.md +++ b/docs/index.md @@ -107,6 +107,7 @@ options for deployment: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) +* [Supplemental Projects](https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects): related third party Spark projects **External Resources:** From 0a7091e689a4c8b1e7b61e9f0873e6557f40d952 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 16 Sep 2014 16:03:20 -0700 Subject: [PATCH 014/315] [SPARK-3555] Fix UISuite race condition The test "jetty selects different port under contention" is flaky. If another process binds to 4040 before the test starts, then the first server we start there will fail, and the subsequent servers we start thereafter may successfully bind to 4040 if it was released between the servers starting. Instead, we should just let Java find a random free port for us and hold onto it for the duration of the test. Author: Andrew Or Closes #2418 from andrewor14/fix-port-contention and squashes the following commits: 0cd4974 [Andrew Or] Stop them servers a7071fe [Andrew Or] Pick random port instead of 4040 --- .../test/scala/org/apache/spark/ui/UISuite.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 48790b59e7fbd..92a21f82f3c21 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -23,7 +23,6 @@ import javax.servlet.http.HttpServletRequest import scala.io.Source import scala.util.{Failure, Success, Try} -import org.eclipse.jetty.server.Server import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ @@ -108,14 +107,8 @@ class UISuite extends FunSuite { } test("jetty selects different port under contention") { - val startPort = 4040 - val server = new Server(startPort) - - Try { server.start() } match { - case Success(s) => - case Failure(e) => - // Either case server port is busy hence setup for test complete - } + val server = new ServerSocket(0) + val startPort = server.getLocalPort val serverInfo1 = JettyUtils.startJettyServer( "0.0.0.0", startPort, Seq[ServletContextHandler](), new SparkConf) val serverInfo2 = JettyUtils.startJettyServer( @@ -126,6 +119,9 @@ class UISuite extends FunSuite { assert(boundPort1 != startPort) assert(boundPort2 != startPort) assert(boundPort1 != boundPort2) + serverInfo1.server.stop() + serverInfo2.server.stop() + server.close() } test("jetty binds to port 0 correctly") { From 008a5ed4808d1467b47c1d6fa4d950cc6c4976b7 Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 16 Sep 2014 21:57:33 -0700 Subject: [PATCH 015/315] [Minor]ignore all config files in conf Some config files in ```conf``` should ignore, such as conf/fairscheduler.xml conf/hive-log4j.properties conf/metrics.properties ... So ignore all ```sh```/```properties```/```conf```/```xml``` files Author: wangfei Closes #2395 from scwf/patch-2 and squashes the following commits: 3dc53f2 [wangfei] duplicate ```conf/*.conf``` 3c2986f [wangfei] ignore all config files --- .gitignore | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 7ec8d45e12c6b..a31bf7e0091f4 100644 --- a/.gitignore +++ b/.gitignore @@ -15,11 +15,10 @@ out/ third_party/libmesos.so third_party/libmesos.dylib conf/java-opts -conf/spark-env.sh -conf/streaming-env.sh -conf/log4j.properties -conf/spark-defaults.conf -conf/hive-site.xml +conf/*.sh +conf/*.properties +conf/*.conf +conf/*.xml docs/_site docs/api target/ @@ -50,7 +49,6 @@ unit-tests.log /lib/ rat-results.txt scalastyle.txt -conf/*.conf scalastyle-output.xml # For Hive From 983609a4dd83e25598455bfce93fa1c1fa9f2c51 Mon Sep 17 00:00:00 2001 From: viper-kun Date: Wed, 17 Sep 2014 00:09:57 -0700 Subject: [PATCH 016/315] [Docs] Correct spark.files.fetchTimeout default value change the value of spark.files.fetchTimeout Author: viper-kun Closes #2406 from viper-kun/master and squashes the following commits: ecb0d46 [viper-kun] [Docs] Correct spark.files.fetchTimeout default value 7cf4c7a [viper-kun] Update configuration.md --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index af16489a44281..99faf51c6f3db 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -520,10 +520,10 @@ Apart from these, the following properties are also available, and may be useful
    spark.files.fetchTimeoutfalse60 Communication timeout to use when fetching files added through SparkContext.addFile() from - the driver. + the driver, in seconds.
    {info.name}{info.id}{info.name} {startTime} {endTime} {duration}spark.port.maxRetries 16 - Maximum number of retries when binding to a port before giving up. + Default maximum number of retries when binding to a port before giving up.
    join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. - Outer joins are also supported through leftOuterJoin and rightOuterJoin. + Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin.
    spark.python.profilefalse + Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + or it will be displayed before the driver exiting. It also can be dumped into disk by + `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, + they will not be displayed automatically before driver exiting. +
    spark.python.profile.dump(none) + The directory which is used to dump the profile result before driver exiting. + The results will be dumped as separated file for each RDD. They can be loaded + by ptats.Stats(). If this is specified, the profile result will not be displayed + automatically. +
    spark.python.worker.reuse true
    spark.python.profilefalse - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, - or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, - they will not be displayed automatically before driver exiting. -
    spark.python.profile.dump(none) - The directory which is used to dump the profile result before driver exiting. - The results will be dumped as separated file for each RDD. They can be loaded - by ptats.Stats(). If this is specified, the profile result will not be displayed - automatically. -
    spark.python.worker.reuse true
    groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    - Note: If you are grouping in order to perform an aggregation (such as a sum or + Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or combineByKey will yield much better performance.
    From a01a30927d107a8d9496f749eb9d89eda6dda9d7 Mon Sep 17 00:00:00 2001 From: shane knapp Date: Tue, 30 Sep 2014 13:11:25 -0700 Subject: [PATCH 136/315] SPARK-3745 - fix check-license to properly download and check jar for details, see: https://issues.apache.org/jira/browse/SPARK-3745 Author: shane knapp Closes #2596 from shaneknapp/SPARK-3745 and squashes the following commits: c95eea9 [shane knapp] SPARK-3745 - fix check-license to properly download and check jar --- dev/check-license | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dev/check-license b/dev/check-license index 9ff0929e9a5e8..72b1013479964 100755 --- a/dev/check-license +++ b/dev/check-license @@ -20,11 +20,10 @@ acquire_rat_jar () { - URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" - URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + URL="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" JAR="$rat_jar" - + if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! -f "$JAR" ]; then @@ -32,15 +31,17 @@ acquire_rat_jar () { printf "Attempting to fetch rat\n" JAR_DL="${JAR}.part" if hash curl 2>/dev/null; then - (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR" + curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" elif hash wget 2>/dev/null; then - (wget --quiet ${URL1} -O "$JAR_DL" || wget --quiet ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 fi fi - if [ ! -f "$JAR" ]; then + + unzip -tq $JAR &> /dev/null + if [ $? -ne 0 ]; then # We failed to download printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" exit -1 @@ -55,7 +56,7 @@ cd "$FWDIR" if test -x "$JAVA_HOME/bin/java"; then declare java_cmd="$JAVA_HOME/bin/java" -else +else declare java_cmd=java fi From d3a3840e077802647aced1ceace1494605dda1db Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Tue, 30 Sep 2014 13:28:41 -0700 Subject: [PATCH 137/315] [Build] Post commit hash with timeout messages [By request](https://github.com/apache/spark/pull/2588#issuecomment-57266871), and because it also makes sense. Author: Nicholas Chammas Closes #2597 from nchammas/timeout-commit-hash and squashes the following commits: 3d90714 [Nicholas Chammas] Revert "testing: making timeout 1 minute" 2353c95 [Nicholas Chammas] testing: making timeout 1 minute e3a477e [Nicholas Chammas] post commit hash with timeout --- dev/run-tests-jenkins | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index a6ecf3196d7d4..0b1e31b9413cf 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -141,8 +141,10 @@ function post_message () { test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** after \ - a configured wait of \`${TESTS_TIMEOUT}\`." + fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \ + for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ + after a configured wait of \`${TESTS_TIMEOUT}\`." + post_message "$fail_message" exit $test_result else From 8764fe368bbd72fe76ed318faad0e97a7279e2fe Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 30 Sep 2014 15:18:51 -0700 Subject: [PATCH 138/315] SPARK-3744 [STREAMING] FlumeStreamSuite will fail during port contention Since it looked quite easy, I took the liberty of making a quick PR that just uses `Utils.startServiceOnPort` to fix this. It works locally for me. Author: Sean Owen Closes #2601 from srowen/SPARK-3744 and squashes the following commits: ddc9319 [Sean Owen] Avoid port contention in tests by retrying several ports for Flume stream --- .../streaming/flume/FlumeStreamSuite.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 6ee7ac974b4a0..33235d150b4a5 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -31,7 +31,7 @@ import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase} import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream +import org.apache.spark.util.Utils import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,21 +41,26 @@ import org.jboss.netty.handler.codec.compression._ class FlumeStreamSuite extends TestSuiteBase { test("flume input stream") { - runFlumeStreamTest(false, 9998) + runFlumeStreamTest(false) } test("flume input compressed stream") { - runFlumeStreamTest(true, 9997) + runFlumeStreamTest(true) } - def runFlumeStreamTest(enableDecompression: Boolean, testPort: Int) { + def runFlumeStreamTest(enableDecompression: Boolean) { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: JavaReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) + val (flumeStream, testPort) = + Utils.startServiceOnPort(9997, (trialPort: Int) => { + val dstream = FlumeUtils.createStream( + ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) + (dstream, trialPort) + }) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream.receiverInputDStream, outputBuffer) + val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() @@ -63,13 +68,13 @@ class FlumeStreamSuite extends TestSuiteBase { val input = Seq(1, 2, 3, 4, 5) Thread.sleep(1000) val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - var client: AvroSourceProtocol = null; - + var client: AvroSourceProtocol = null + if (enableDecompression) { client = SpecificRequestor.getClient( classOf[AvroSourceProtocol], new NettyTransceiver(new InetSocketAddress("localhost", testPort), - new CompressionChannelFactory(6))); + new CompressionChannelFactory(6))) } else { client = SpecificRequestor.getClient( classOf[AvroSourceProtocol], transceiver) From 6c696d7da64e764111b680b1eee040a61f944c26 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Sep 2014 15:55:04 -0700 Subject: [PATCH 139/315] Remove compiler warning from TaskContext change. Author: Reynold Xin Closes #2602 from rxin/warning and squashes the following commits: 130186b [Reynold Xin] Remove compiler warning from TaskContext change. --- .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 929ded58a3bd5..0d97506450a7f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outfmt.newInstance @@ -1027,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt - writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.setup(context.getStageId, context.getPartitionId, attemptNumber) writer.open() try { var count = 0 From d75496b1898dace4da1cf95e53c38093f8f95221 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 30 Sep 2014 17:10:36 -0700 Subject: [PATCH 140/315] [SPARK-3701][MLLIB] update python linalg api and small fixes 1. doc updates 2. simple checks on vector dimensions 3. use column major for matrices davies jkbradley Author: Xiangrui Meng Closes #2548 from mengxr/mllib-py-clean and squashes the following commits: 6dce2df [Xiangrui Meng] address comments 116b5db [Xiangrui Meng] use np.dot instead of array.dot 75f2fcc [Xiangrui Meng] fix python style fefce00 [Xiangrui Meng] better check of vector size with more tests 067ef71 [Xiangrui Meng] majored -> major ef853f9 [Xiangrui Meng] update python linalg api and small fixes --- .../apache/spark/mllib/linalg/Matrices.scala | 8 +- python/pyspark/mllib/linalg.py | 150 ++++++++++++++---- 2 files changed, 125 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 4e87fe088ecc5..2cc52e94282ba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -85,7 +85,7 @@ sealed trait Matrix extends Serializable { } /** - * Column-majored dense matrix. + * Column-major dense matrix. * The entry values are stored in a single array of doubles with columns listed in sequence. * For example, the following matrix * {{{ @@ -128,7 +128,7 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) } /** - * Column-majored sparse matrix. + * Column-major sparse matrix. * The entry values are stored in Compressed Sparse Column (CSC) format. * For example, the following matrix * {{{ @@ -207,7 +207,7 @@ class SparseMatrix( object Matrices { /** - * Creates a column-majored dense matrix. + * Creates a column-major dense matrix. * * @param numRows number of rows * @param numCols number of columns @@ -218,7 +218,7 @@ object Matrices { } /** - * Creates a column-majored sparse matrix in Compressed Sparse Column (CSC) format. + * Creates a column-major sparse matrix in Compressed Sparse Column (CSC) format. * * @param numRows number of rows * @param numCols number of columns diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 0a5dcaac55e46..51014a8ceb785 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -63,6 +63,41 @@ def _convert_to_vector(l): raise TypeError("Cannot convert type %s into Vector" % type(l)) +def _vector_size(v): + """ + Returns the size of the vector. + + >>> _vector_size([1., 2., 3.]) + 3 + >>> _vector_size((1., 2., 3.)) + 3 + >>> _vector_size(array.array('d', [1., 2., 3.])) + 3 + >>> _vector_size(np.zeros(3)) + 3 + >>> _vector_size(np.zeros((3, 1))) + 3 + >>> _vector_size(np.zeros((1, 3))) + Traceback (most recent call last): + ... + ValueError: Cannot treat an ndarray of shape (1, 3) as a vector + """ + if isinstance(v, Vector): + return len(v) + elif type(v) in (array.array, list, tuple): + return len(v) + elif type(v) == np.ndarray: + if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): + return len(v) + else: + raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) + elif _have_scipy and scipy.sparse.issparse(v): + assert v.shape[1] == 1, "Expected column vector" + return v.shape[0] + else: + raise TypeError("Cannot treat type %s as a vector" % type(v)) + + class Vector(object): """ Abstract class for DenseVector and SparseVector @@ -76,6 +111,9 @@ def toArray(self): class DenseVector(Vector): + """ + A dense vector represented by a value array. + """ def __init__(self, ar): if not isinstance(ar, array.array): ar = array.array('d', ar) @@ -100,15 +138,31 @@ def dot(self, other): 5.0 >>> dense.dot(np.array(range(1, 3))) 5.0 + >>> dense.dot([1.,]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) + array([ 5., 11.]) + >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F')) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch """ - if isinstance(other, SparseVector): - return other.dot(self) + if type(other) == np.ndarray and other.ndim > 1: + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.toArray(), other) elif _have_scipy and scipy.sparse.issparse(other): - return other.transpose().dot(self.toArray())[0] - elif isinstance(other, Vector): - return np.dot(self.toArray(), other.toArray()) + assert len(self) == other.shape[0], "dimension mismatch" + return other.transpose().dot(self.toArray()) else: - return np.dot(self.toArray(), other) + assert len(self) == _vector_size(other), "dimension mismatch" + if isinstance(other, SparseVector): + return other.dot(self) + elif isinstance(other, Vector): + return np.dot(self.toArray(), other.toArray()) + else: + return np.dot(self.toArray(), other) def squared_distance(self, other): """ @@ -126,7 +180,16 @@ def squared_distance(self, other): >>> sparse1 = SparseVector(2, [0, 1], [2., 1.]) >>> dense1.squared_distance(sparse1) 2.0 + >>> dense1.squared_distance([1.,]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> dense1.squared_distance(SparseVector(1, [0,], [1.,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch """ + assert len(self) == _vector_size(other), "dimension mismatch" if isinstance(other, SparseVector): return other.squared_distance(self) elif _have_scipy and scipy.sparse.issparse(other): @@ -165,12 +228,10 @@ def __getattr__(self, item): class SparseVector(Vector): - """ A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's {scipy.sparse} data types. """ - def __init__(self, size, *args): """ Create a sparse vector, using either a dictionary, a list of @@ -222,20 +283,33 @@ def dot(self, other): 0.0 >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) array([ 22., 22.]) + >>> a.dot([1., 2., 3.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.array([1., 2.])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(DenseVector([1., 2.])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.zeros((3, 2))) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch """ if type(other) == np.ndarray: - if other.ndim == 1: - result = 0.0 - for i in xrange(len(self.indices)): - result += self.values[i] * other[self.indices[i]] - return result - elif other.ndim == 2: + if other.ndim == 2: results = [self.dot(other[:, i]) for i in xrange(other.shape[1])] return np.array(results) - else: - raise Exception("Cannot call dot with %d-dimensional array" % other.ndim) + elif other.ndim > 2: + raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) + + assert len(self) == _vector_size(other), "dimension mismatch" - elif type(other) in (array.array, DenseVector): + if type(other) in (np.ndarray, array.array, DenseVector): result = 0.0 for i in xrange(len(self.indices)): result += self.values[i] * other[self.indices[i]] @@ -254,6 +328,7 @@ def dot(self, other): else: j += 1 return result + else: return self.dot(_convert_to_vector(other)) @@ -273,7 +348,16 @@ def squared_distance(self, other): 30.0 >>> b.squared_distance(a) 30.0 + >>> b.squared_distance([1., 2.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> b.squared_distance(SparseVector(3, [1,], [1.0,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch """ + assert len(self) == _vector_size(other), "dimension mismatch" if type(other) in (list, array.array, DenseVector, np.array, np.ndarray): if type(other) is np.array and other.ndim != 1: raise Exception("Cannot call squared_distance with %d-dimensional array" % @@ -348,7 +432,6 @@ def __eq__(self, other): >>> v1 != v2 False """ - return (isinstance(other, self.__class__) and other.size == self.size and other.indices == self.indices @@ -414,23 +497,32 @@ def stringify(vector): class Matrix(object): - """ the Matrix """ - def __init__(self, nRow, nCol): - self.nRow = nRow - self.nCol = nCol + """ + Represents a local matrix. + """ + + def __init__(self, numRows, numCols): + self.numRows = numRows + self.numCols = numCols def toArray(self): + """ + Returns its elements in a NumPy ndarray. + """ raise NotImplementedError class DenseMatrix(Matrix): - def __init__(self, nRow, nCol, values): - Matrix.__init__(self, nRow, nCol) - assert len(values) == nRow * nCol + """ + Column-major dense matrix. + """ + def __init__(self, numRows, numCols, values): + Matrix.__init__(self, numRows, numCols) + assert len(values) == numRows * numCols self.values = values def __reduce__(self): - return DenseMatrix, (self.nRow, self.nCol, self.values) + return DenseMatrix, (self.numRows, self.numCols, self.values) def toArray(self): """ @@ -439,10 +531,10 @@ def toArray(self): >>> arr = array.array('d', [float(i) for i in range(4)]) >>> m = DenseMatrix(2, 2, arr) >>> m.toArray() - array([[ 0., 1.], - [ 2., 3.]]) + array([[ 0., 2.], + [ 1., 3.]]) """ - return np.ndarray((self.nRow, self.nCol), np.float64, buffer=self.values.tostring()) + return np.reshape(self.values, (self.numRows, self.numCols), order='F') def _test(): From c5414b681868a0a11cc5a94184116e66e8d3e9c0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Sep 2014 18:24:57 -0700 Subject: [PATCH 141/315] [SPARK-3478] [PySpark] Profile the Python tasks This patch add profiling support for PySpark, it will show the profiling results before the driver exits, here is one example: ``` ============================================================ Profile of RDD ============================================================ 5146507 function calls (5146487 primitive calls) in 71.094 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 5144576 68.331 0.000 68.331 0.000 statcounter.py:44(merge) 20 2.735 0.137 71.071 3.554 statcounter.py:33(__init__) 20 0.017 0.001 0.017 0.001 {cPickle.dumps} 1024 0.003 0.000 0.003 0.000 t.py:16() 20 0.001 0.000 0.001 0.000 {reduce} 21 0.001 0.000 0.001 0.000 {cPickle.loads} 20 0.001 0.000 0.001 0.000 copy_reg.py:95(_slotnames) 41 0.001 0.000 0.001 0.000 serializers.py:461(read_int) 40 0.001 0.000 0.002 0.000 serializers.py:179(_batched) 62 0.000 0.000 0.000 0.000 {method 'read' of 'file' objects} 20 0.000 0.000 71.072 3.554 rdd.py:863() 20 0.000 0.000 0.001 0.000 serializers.py:198(load_stream) 40/20 0.000 0.000 71.072 3.554 rdd.py:2093(pipeline_func) 41 0.000 0.000 0.002 0.000 serializers.py:130(load_stream) 40 0.000 0.000 71.072 1.777 rdd.py:304(func) 20 0.000 0.000 71.094 3.555 worker.py:82(process) ``` Also, use can show profile result manually by `sc.show_profiles()` or dump it into disk by `sc.dump_profiles(path)`, such as ```python >>> sc._conf.set("spark.python.profile", "true") >>> rdd = sc.parallelize(range(100)).map(str) >>> rdd.count() 100 >>> sc.show_profiles() ============================================================ Profile of RDD ============================================================ 284 function calls (276 primitive calls) in 0.001 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 4 0.000 0.000 0.000 0.000 serializers.py:198(load_stream) 4 0.000 0.000 0.000 0.000 {reduce} 12/4 0.000 0.000 0.001 0.000 rdd.py:2092(pipeline_func) 4 0.000 0.000 0.000 0.000 {cPickle.loads} 4 0.000 0.000 0.000 0.000 {cPickle.dumps} 104 0.000 0.000 0.000 0.000 rdd.py:852() 8 0.000 0.000 0.000 0.000 serializers.py:461(read_int) 12 0.000 0.000 0.000 0.000 rdd.py:303(func) ``` The profiling is disabled by default, can be enabled by "spark.python.profile=true". Also, users can dump the results into disks automatically for future analysis, by "spark.python.profile.dump=path_to_dump" This is bugfix of #2351 cc JoshRosen Author: Davies Liu Closes #2556 from davies/profiler and squashes the following commits: e68df5a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 858e74c [Davies Liu] compatitable with python 2.6 7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles() 2b0daf2 [Davies Liu] fix docs 7a56c24 [Davies Liu] bugfix cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 09d02c3 [Davies Liu] Merge branch 'master' into profiler c23865c [Davies Liu] Merge branch 'master' into profiler 15d6f18 [Davies Liu] add docs for two configs dadee1a [Davies Liu] add docs string and clear profiles after show or dump 4f8309d [Davies Liu] address comment, add tests 0a5b6eb [Davies Liu] fix Python UDF 4b20494 [Davies Liu] add profile for python --- docs/configuration.md | 19 +++++++++++++++++ python/pyspark/accumulators.py | 15 +++++++++++++ python/pyspark/context.py | 39 +++++++++++++++++++++++++++++++++- python/pyspark/rdd.py | 10 +++++++-- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 30 ++++++++++++++++++++++++++ python/pyspark/worker.py | 19 ++++++++++++++--- 7 files changed, 127 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index a6dd7245e1552..791b6f2aa3261 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks.
    spark.python.profilefalse + Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + or it will be displayed before the driver exiting. It also can be dumped into disk by + `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, + they will not be displayed automatically before driver exiting. +
    spark.python.profile.dump(none) + The directory which is used to dump the profile result before driver exiting. + The results will be dumped as separated file for each RDD. They can be loaded + by ptats.Stats(). If this is specified, the profile result will not be displayed + automatically. +
    spark.python.worker.reuse true
    StructType org.apache.spark.sql.api.java org.apache.spark.sql.api.java.Row DataType.createStructType(fields)
    Note: fields is a List or an array of StructFields. From b4fb7b80a0d863500943d788ad3e34d502a6dafa Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Thu, 2 Oct 2014 13:48:35 -0500 Subject: [PATCH 169/315] Modify default YARN memory_overhead-- from an additive constant to a multiplier Redone against the recent master branch (https://github.com/apache/spark/pull/1391) Author: Nishkam Ravi Author: nravi Author: nishkamravi2 Closes #2485 from nishkamravi2/master_nravi and squashes the following commits: 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- docs/running-on-yarn.md | 8 ++++---- .../spark/deploy/yarn/ClientArguments.scala | 16 +++++++++------- .../apache/spark/deploy/yarn/ClientBase.scala | 12 ++++++++---- .../apache/spark/deploy/yarn/YarnAllocator.scala | 16 ++++++++-------- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 8 ++++++-- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4b3a49eca7007..695813a2ba881 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -79,16 +79,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
    spark.yarn.executor.memoryOverhead384executorMemory * 0.07, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%).
    spark.yarn.driver.memoryOverhead384driverMemory * 0.07, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
    spark.io.compression.codec snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, - Spark provides three codecs: lz4, lzf, and snappy. You - can also use fully qualified class names to specify the codec, e.g. - org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions, broadcast variables and + shuffle outputs. By default, Spark provides three codecs: lz4, lzf, + and snappy. You can also use fully qualified class names to specify the codec, + e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec.
    spark.mesos.executor.memoryOverheadexecutor memory * 0.07, with minimum of 384 + This value is an additive for spark.executor.memory, specified in MiB, + which is used to calculate the total Mesos task memory. A value of 384 + implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum + overhead. The final overhead will be the larger of either + `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. +
    #### Shuffle Behavior From 358d7ffd01b4a3fbae313890522cf662c71af6e5 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Fri, 3 Oct 2014 13:09:48 -0700 Subject: [PATCH 185/315] [SPARK-3775] Not suitable error message in spark-shell.cmd Modified some sentence of error message in bin\*.cmd. Author: Masayoshi TSUZUKI Closes #2640 from tsudukim/feature/SPARK-3775 and squashes the following commits: 3458afb [Masayoshi TSUZUKI] [SPARK-3775] Not suitable error message in spark-shell.cmd --- bin/pyspark2.cmd | 2 +- bin/run-example2.cmd | 2 +- bin/spark-class | 2 +- bin/spark-class2.cmd | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 2c4b08af8d4c3..a0e66abcc26c9 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -33,7 +33,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop* ) if [%FOUND_JAR%] == [0] ( echo Failed to find Spark assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) :skip_build_test diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b29bf90c64e90..b49d0dcb4ff2d 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -52,7 +52,7 @@ if exist "%FWDIR%RELEASE" ( ) if "x%SPARK_EXAMPLES_JAR%"=="x" ( echo Failed to find Spark examples assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) diff --git a/bin/spark-class b/bin/spark-class index 613dc9c4566f2..e8201c18d52de 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -146,7 +146,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 - echo "You need to build spark before running $1." 1>&2 + echo "You need to build Spark before running $1." 1>&2 exit 1 fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 6c5672819172b..da46543647efd 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -104,7 +104,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop* ) if "%FOUND_JAR%"=="0" ( echo Failed to find Spark assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) :skip_build_test From e5566e05b1ac99aa6caf1701e47ebcdb68a002c6 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Fri, 3 Oct 2014 13:12:37 -0700 Subject: [PATCH 186/315] [SPARK-3774] typo comment in bin/utils.sh Modified the comment of bin/utils.sh. Author: Masayoshi TSUZUKI Closes #2639 from tsudukim/feature/SPARK-3774 and squashes the following commits: 707b779 [Masayoshi TSUZUKI] [SPARK-3774] typo comment in bin/utils.sh --- bin/utils.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/utils.sh b/bin/utils.sh index 0804b1ed9f231..22ea2b9a6d586 100755 --- a/bin/utils.sh +++ b/bin/utils.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# Gather all all spark-submit options into SUBMISSION_OPTS +# Gather all spark-submit options into SUBMISSION_OPTS function gatherSparkSubmitOpts() { if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then From 30abef154768e5c4c6062f3341933dbda990f6cc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Oct 2014 13:18:35 -0700 Subject: [PATCH 187/315] [SPARK-3606] [yarn] Correctly configure AmIpFilter for Yarn HA. The existing code only considered one of the RMs when running in Yarn HA mode, so it was possible to get errors if the active RM was not registered in the filter. The change makes use of a new API added to Yarn that returns all proxy addresses, and falls back to the old behavior if the API is not present. While there, I also made a change to look for the scheme (http or https) being used by Yarn when building the proxy URIs. Since, in the case of multiple RMs, Yarn uses commas as a separator, it was not possible anymore to use spark.filter.params to propagate this information (which used commas to delimit different config params). Instead, I added a new param (spark.filter.jsonParams) which expects a JSON string containing a map with the config data. I chose not to add it to the documentation at this point since I don't believe users will use it directly. Author: Marcelo Vanzin Closes #2469 from vanzin/SPARK-3606 and squashes the following commits: aeb458a [Marcelo Vanzin] Undelete needed import. 65e400d [Marcelo Vanzin] Remove unused import. d121883 [Marcelo Vanzin] Use separate config for each param instead of json. 04bc156 [Marcelo Vanzin] Review feedback. 4d4d6b9 [Marcelo Vanzin] [SPARK-3606] [yarn] Correctly configure AmIpFilter for Yarn HA. --- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 8 ++++-- .../org/apache/spark/ui/JettyUtils.scala | 14 ++++++---- .../spark/deploy/yarn/YarnRMClientImpl.scala | 8 ++++-- .../spark/deploy/yarn/ApplicationMaster.scala | 12 +++----- .../spark/deploy/yarn/YarnRMClient.scala | 4 +-- .../spark/deploy/yarn/YarnRMClientImpl.scala | 28 ++++++++++++++++++- 7 files changed, 53 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6abf6d930c155..fb8160abc59db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,7 +66,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage - case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String) + case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 89089e7d6f8a8..59aed6b72fe42 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -275,15 +275,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } // Add filters to the SparkUI - def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) { + def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) { if (proxyBase != null && proxyBase.nonEmpty) { System.setProperty("spark.ui.proxyBase", proxyBase) } - if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) { + val hasFilter = (filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty) + if (hasFilter) { logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") conf.set("spark.ui.filters", filterName) - conf.set(s"spark.$filterName.params", filterParams) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 6b4689291097f..2a27d49d2de05 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -21,9 +21,7 @@ import java.net.{InetSocketAddress, URL} import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import scala.annotation.tailrec import scala.language.implicitConversions -import scala.util.{Failure, Success, Try} import scala.xml.Node import org.eclipse.jetty.server.Server @@ -147,15 +145,19 @@ private[spark] object JettyUtils extends Logging { val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter - val paramName = "spark." + filter + ".params" - val params = conf.get(paramName, "").split(',').map(_.trim()).toSet - params.foreach { - case param : String => + conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach { + param: String => if (!param.isEmpty) { val parts = param.split("=") if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) } } + + val prefix = s"spark.$filter.param." + conf.getAll + .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) } + .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index acf26505e4cf9..9bd1719cb1808 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -76,8 +76,12 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC resourceManager.finishApplicationMaster(finishReq) } - override def getProxyHostAndPort(conf: YarnConfiguration) = - YarnConfiguration.getProxyHostAndPort(conf) + override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val uriBase = "http://" + proxy + proxyBase + Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase) + } override def getMaxRegAttempts(conf: YarnConfiguration) = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b51daeb437516..caceef5d4b5b0 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -368,18 +368,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter() = { - val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - val proxy = client.getProxyHostAndPort(yarnConf) - val parts = proxy.split(":") val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - val uriBase = "http://" + proxy + proxyBase - val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val params = client.getAmIpFilterParams(yarnConf, proxyBase) if (isDriver) { System.setProperty("spark.ui.filters", amFilter) - System.setProperty(s"spark.$amFilter.params", params) + params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params, proxyBase) + actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index ed65e56b3e413..943dc56202a37 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -59,8 +59,8 @@ trait YarnRMClient { /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId - /** Returns the RM's proxy host and port. */ - def getProxyHostAndPort(conf: YarnConfiguration): String + /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ + def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String): Map[String, String] /** Returns the maximum number of attempts to register the AM. */ def getMaxRegAttempts(conf: YarnConfiguration): Int diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 54bc6b14c44ce..b581790e158ac 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -17,8 +17,13 @@ package org.apache.spark.deploy.yarn +import java.util.{List => JList} + import scala.collection.{Map, Set} +import scala.collection.JavaConversions._ +import scala.util._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ @@ -69,7 +74,28 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appAttemptId } - override def getProxyHostAndPort(conf: YarnConfiguration) = WebAppUtils.getProxyHostAndPort(conf) + override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { + // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2, + // so not all stable releases have it. + val prefix = Try(classOf[WebAppUtils].getMethod("getHttpSchemePrefix", classOf[Configuration]) + .invoke(null, conf).asInstanceOf[String]).getOrElse("http://") + + // If running a new enough Yarn, use the HA-aware API for retrieving the RM addresses. + try { + val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", + classOf[Configuration]) + val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] + val hosts = proxies.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + } catch { + case e: NoSuchMethodException => + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val uriBase = prefix + proxy + proxyBase + Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase) + } + } override def getMaxRegAttempts(conf: YarnConfiguration) = conf.getInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) From 1eb8389cb4ad40a405149b16e2719e12367d667a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 3 Oct 2014 13:26:30 -0700 Subject: [PATCH 188/315] [SPARK-3763] The example of building with sbt should be "sbt assembly" instead of "sbt compile" In building-spark.md, there are some examples for making assembled package with maven but the example for building with sbt is only about for compiling. Author: Kousuke Saruta Closes #2627 from sarutak/SPARK-3763 and squashes the following commits: fadb990 [Kousuke Saruta] Modified the example to build with sbt in building-spark.md --- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 2378092d4a1a8..901c157162fee 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -169,7 +169,7 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. For example: - sbt/sbt -Pyarn -Phadoop-2.3 compile + sbt/sbt -Pyarn -Phadoop-2.3 assembly # Speeding up Compilation with Zinc From 79e45c9323455a51f25ed9acd0edd8682b4bbb88 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 3 Oct 2014 13:48:56 -0700 Subject: [PATCH 189/315] [SPARK-3377] [SPARK-3610] Metrics can be accidentally aggregated / History server log name should not be based on user input This PR is another solution for #2250 I'm using codahale base MetricsSystem of Spark with JMX or Graphite, and I saw following 2 problems. (1) When applications which have same spark.app.name run on cluster at the same time, some metrics names are mixed. For instance, if 2+ application is running on the cluster at the same time, each application emits the same named metric like "SparkPi.DAGScheduler.stage.failedStages" and Graphite cannot distinguish the metrics is for which application. (2) When 2+ executors run on the same machine, JVM metrics of each executors are mixed. For instance, 2+ executors running on the same node can emit the same named metric "jvm.memory" and Graphite cannot distinguish the metrics is from which application. And there is an similar issue. The directory for event logs is named using application name. Application name is defined by user and the name can includes illegal character for path names. Further more, the directory name consists of application name and System.currentTimeMillis even though each application has unique Application ID so if we run jobs which have same name, it's difficult to identify which directory is for which application. Closes #2250 Closes #1067 Author: Kousuke Saruta Closes #2432 from sarutak/metrics-structure-improvement2 and squashes the following commits: 3288b2b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 39169e4 [Kousuke Saruta] Fixed style 6570494 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 817e4f0 [Kousuke Saruta] Simplified MetricsSystem#buildRegistryName 67fa5eb [Kousuke Saruta] Unified MetricsSystem#registerSources and registerSinks in start 10be654 [Kousuke Saruta] Fixed style. 990c078 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 f0c7fba [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 59cc2cd [Kousuke Saruta] Modified SparkContextSchedulerCreationSuite f9b6fb3 [Kousuke Saruta] Modified style. 2cf8a0f [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 389090d [Kousuke Saruta] Replaced taskScheduler.applicationId() with getApplicationId in SparkContext#postApplicationStart ff45c89 [Kousuke Saruta] Added some test cases to MetricsSystemSuite 69c46a6 [Kousuke Saruta] Added warning logging logic to MetricsSystem#buildRegistryName 5cca0d2 [Kousuke Saruta] Added Javadoc comment to SparkContext#getApplicationId 16a9f01 [Kousuke Saruta] Added data types to be returned to some methods 6434b06 [Kousuke Saruta] Reverted changes related to ApplicationId 0413b90 [Kousuke Saruta] Deleted ApplicationId.java and ApplicationIdSuite.java a42300c [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 0fc1b09 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 42bea55 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 248935d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 f6af132 [Kousuke Saruta] Modified SchedulerBackend and TaskScheduler to return System.currentTimeMillis as an unique Application Id 1b8b53e [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 97cb85c [Kousuke Saruta] Modified confliction of MimExcludes 2cdd009 [Kousuke Saruta] Modified defailt implementation of applicationId 9aadb0b [Kousuke Saruta] Modified NetworkReceiverSuite to ensure "executor.start()" is finished in test "network receiver life cycle" 3011efc [Kousuke Saruta] Added ApplicationIdSuite.scala d009c55 [Kousuke Saruta] Modified ApplicationId#equals to compare appIds dfc83fd [Kousuke Saruta] Modified ApplicationId to implement Serializable 9ff4851 [Kousuke Saruta] Modified MimaExcludes.scala to ignore createTaskScheduler method in SparkContext 4567ffc [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 6a91b14 [Kousuke Saruta] Modified SparkContextSchedulerCreationSuite, ExecutorRunnerTest and EventLoggingListenerSuite 0325caf [Kousuke Saruta] Added ApplicationId.scala 0a2fc14 [Kousuke Saruta] Modified style eabda80 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 0f890e6 [Kousuke Saruta] Modified SparkDeploySchedulerBackend and Master to pass baseLogDir instead f eventLogDir bcf25bf [Kousuke Saruta] Modified directory name for EventLogs 28d4d93 [Kousuke Saruta] Modified SparkContext and EventLoggingListener so that the directory for EventLogs is named same for Application ID 203634e [Kousuke Saruta] Modified comment in SchedulerBackend#applicationId and TaskScheduler#applicationId 424fea4 [Kousuke Saruta] Modified the subclasses of TaskScheduler and SchedulerBackend so that they can return non-optional Unique Application ID b311806 [Kousuke Saruta] Swapped last 2 arguments passed to CoarseGrainedExecutorBackend 8a2b6ec [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 086ee25 [Kousuke Saruta] Merge branch 'metrics-structure-improvement2' of github.com:sarutak/spark into metrics-structure-improvement2 e705386 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 36d2f7a [Kousuke Saruta] Added warning message for the situation we cannot get application id for the prefix for the name of metrics eea6e19 [Kousuke Saruta] Modified CoarseGrainedMesosSchedulerBackend and MesosSchedulerBackend so that we can get Application ID c229fbe [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 e719c39 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 4a93c7f [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 4776f9e [Kousuke Saruta] Modified MetricsSystemSuite.scala efcb6e1 [Kousuke Saruta] Modified to add application id to metrics name 2ec848a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 3ea7896 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement ead8966 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 08e627e [Kousuke Saruta] Revert "tmp" 7b67f5a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 45bd33d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 93e263a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 848819c [Kousuke Saruta] Merge branch 'metrics-structure-improvement' of github.com:sarutak/spark into metrics-structure-improvement 912a637 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement e4a4593 [Kousuke Saruta] tmp 3e098d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 4603a39 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement fa7175b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 15f88a3 [Kousuke Saruta] Modified MetricsSystem#buildRegistryName because conf.get does not return null when correspondin entry is absent 6f7dcd4 [Kousuke Saruta] Modified constructor of DAGSchedulerSource and BlockManagerSource because the instance of SparkContext is no longer used 6fc5560 [Kousuke Saruta] Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource 4e057c9 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 85ffc02 [Kousuke Saruta] Revert "Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource" 868e326 [Kousuke Saruta] Modified MetricsSystem to set registry name with unique application-id and driver/executor-id 71609f5 [Kousuke Saruta] Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource 55debab [Kousuke Saruta] Modified SparkContext and Executor to set spark.executor.id to identifiers 4180993 [Kousuke Saruta] Modified SparkContext to retain spark.unique.app.name property in SparkConf --- .../scala/org/apache/spark/SparkContext.scala | 52 ++++--- .../scala/org/apache/spark/SparkEnv.scala | 8 +- .../apache/spark/deploy/master/Master.scala | 12 +- .../CoarseGrainedExecutorBackend.scala | 16 ++- .../org/apache/spark/executor/Executor.scala | 1 + .../spark/executor/ExecutorSource.scala | 3 +- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../apache/spark/metrics/MetricsSystem.scala | 40 +++++- .../spark/scheduler/DAGSchedulerSource.scala | 4 +- .../scheduler/EventLoggingListener.scala | 33 +++-- .../spark/scheduler/SchedulerBackend.scala | 8 +- .../spark/scheduler/TaskScheduler.scala | 8 +- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 9 +- .../mesos/CoarseMesosSchedulerBackend.scala | 11 +- .../cluster/mesos/MesosSchedulerBackend.scala | 13 +- .../spark/scheduler/local/LocalBackend.scala | 3 + .../spark/storage/BlockManagerSource.scala | 4 +- .../spark/metrics/MetricsSystemSuite.scala | 128 +++++++++++++++++- .../scheduler/EventLoggingListenerSuite.scala | 14 +- .../spark/scheduler/ReplayListenerSuite.scala | 3 +- .../streaming/NetworkReceiverSuite.scala | 14 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 +- .../deploy/yarn/ExecutorRunnableUtil.scala | 2 + .../spark/deploy/yarn/YarnAllocator.scala | 2 + .../cluster/YarnClientSchedulerBackend.scala | 6 +- .../cluster/YarnClusterSchedulerBackend.scala | 9 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 +- .../deploy/yarn/YarnAllocationHandler.scala | 2 +- 29 files changed, 331 insertions(+), 85 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 979d178c35969..97109b9f41b60 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -187,6 +187,15 @@ class SparkContext(config: SparkConf) extends Logging { val master = conf.get("spark.master") val appName = conf.get("spark.app.name") + private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) + private[spark] val eventLogDir: Option[String] = { + if (isEventLogEnabled) { + Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + } else { + None + } + } + // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() @@ -200,6 +209,7 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] val listenerBus = new LiveListenerBus // Create the Spark execution environment (cache, map output tracker, etc) + conf.set("spark.executor.id", "driver") private[spark] val env = SparkEnv.create( conf, "", @@ -232,19 +242,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (conf.getBoolean("spark.eventLog.enabled", false)) { - val logger = new EventLoggingListener(appName, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } - - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() - val startTime = System.currentTimeMillis() // Add each JAR given through the constructor @@ -309,6 +306,29 @@ class SparkContext(config: SparkConf) extends Logging { // constructor taskScheduler.start() + val applicationId: String = taskScheduler.applicationId() + conf.set("spark.app.id", applicationId) + + val metricsSystem = env.metricsSystem + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + + // Optionally log Spark events + private[spark] val eventLogger: Option[EventLoggingListener] = { + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else None + } + + // At this point, all relevant SparkListeners have been registered, so begin releasing events + listenerBus.start() + private[spark] val cleaner: Option[ContextCleaner] = { if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) @@ -411,8 +431,8 @@ class SparkContext(config: SparkConf) extends Logging { // Post init taskScheduler.postStartHook() - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) + private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) + private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) private def initDriverMetrics() { SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) @@ -1278,7 +1298,7 @@ class SparkContext(config: SparkConf) extends Logging { private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). - listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(), + listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), startTime, sparkUser)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 009ed64775844..72cac42cd2b2b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -259,11 +259,15 @@ object SparkEnv extends Logging { } val metricsSystem = if (isDriver) { + // Don't start metrics system right now for Driver. + // We need to wait for the task scheduler to give us an app ID. + // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf, securityManager) + val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) + ms.start() + ms } - metricsSystem.start() // Set the sparkFiles directory, used when downloading dependencies. In local mode, // this is a temporary directory; in distributed mode, this is the executor's current working diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 432b552c58cd8..f98b531316a3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -33,8 +33,8 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, - SparkHadoopUtil} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, + ExecutorState, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState @@ -693,16 +693,18 @@ private[spark] class Master( app.desc.appUiUrl = notFoundBasePath return false } - val fileSystem = Utils.getHadoopFileSystem(eventLogDir, + + val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id) + val fileSystem = Utils.getHadoopFileSystem(appEventLogDir, SparkHadoopUtil.get.newConfiguration(conf)) - val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) + val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem) val eventLogPaths = eventLogInfo.logPaths val compressionCodec = eventLogInfo.compressionCodec if (eventLogPaths.isEmpty) { // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" - var msg = s"No event logs found for application $appName in $eventLogDir." + var msg = s"No event logs found for application $appName in $appEventLogDir." logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 13af5b6f5812d..06061edfc0844 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -106,6 +106,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { executorId: String, hostname: String, cores: Int, + appId: String, workerUrl: Option[String]) { SignalLogger.register(log) @@ -122,7 +123,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] + val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create a new ActorSystem using driver's Spark properties to run the backend. @@ -144,16 +146,16 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { def main(args: Array[String]) { args.length match { - case x if x < 4 => + case x if x < 5 => System.err.println( // Worker url is used in spark standalone mode to enforce fate-sharing with worker "Usage: CoarseGrainedExecutorBackend " + - " []") + " [] ") System.exit(1) - case 4 => - run(args(0), args(1), args(2), args(3).toInt, None) - case x if x > 4 => - run(args(0), args(1), args(2), args(3).toInt, Some(args(4))) + case 5 => + run(args(0), args(1), args(2), args(3).toInt, args(4), None) + case x if x > 5 => + run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) } } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d7211ae465902..9bbfcdc4a0b6e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -74,6 +74,7 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) + conf.set("spark.executor.id", "executor." + executorId) private val env = { if (!isLocal) { val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index d6721586566c2..c4d73622c4727 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -37,8 +37,7 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) override val metricRegistry = new MetricRegistry() - // TODO: It would be nice to pass the application name here - override val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor" // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index a42c8b43bbf7f..bca0b152268ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -52,7 +52,8 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver - val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) + val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ + Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) executor = new Executor( executorInfo.getExecutorId.getValue, slaveInfo.getHostname, diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index fd316a89a1a10..5dd67b0cbf683 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -83,10 +83,10 @@ private[spark] class MetricsSystem private ( def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array()) metricsConfig.initialize() - registerSources() - registerSinks() def start() { + registerSources() + registerSinks() sinks.foreach(_.start) } @@ -98,10 +98,39 @@ private[spark] class MetricsSystem private ( sinks.foreach(_.report()) } + /** + * Build a name that uniquely identifies each metric source. + * The name is structured as follows: ... + * If either ID is not available, this defaults to just using . + * + * @param source Metric source to be named by this method. + * @return An unique metric name for each combination of + * application, executor/driver and metric source. + */ + def buildRegistryName(source: Source): String = { + val appId = conf.getOption("spark.app.id") + val executorId = conf.getOption("spark.executor.id") + val defaultName = MetricRegistry.name(source.sourceName) + + if (instance == "driver" || instance == "executor") { + if (appId.isDefined && executorId.isDefined) { + MetricRegistry.name(appId.get, executorId.get, source.sourceName) + } else { + // Only Driver and Executor are set spark.app.id and spark.executor.id. + // For instance, Master and Worker are not related to a specific application. + val warningMsg = s"Using default name $defaultName for source because %s is not set." + if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } + if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } + defaultName + } + } else { defaultName } + } + def registerSource(source: Source) { sources += source try { - registry.register(source.sourceName, source.metricRegistry) + val regName = buildRegistryName(source) + registry.register(regName, source.metricRegistry) } catch { case e: IllegalArgumentException => logInfo("Metrics already registered", e) } @@ -109,8 +138,9 @@ private[spark] class MetricsSystem private ( def removeSource(source: Source) { sources -= source + val regName = buildRegistryName(source) registry.removeMatching(new MetricFilter { - def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName) + def matches(name: String, metric: Metric): Boolean = name.startsWith(regName) }) } @@ -125,7 +155,7 @@ private[spark] class MetricsSystem private ( val source = Class.forName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { - case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e) + case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 94944399b134a..12668b6c0988e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry} import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) +private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() - override val sourceName = "%s.DAGScheduler".format(sc.appName) + override val sourceName = "DAGScheduler" metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 64b32ae0edaac..100c9ba9b7809 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -43,38 +43,29 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} * spark.eventLog.buffer.kb - Buffer size to use when writing to output streams */ private[spark] class EventLoggingListener( - appName: String, + appId: String, + logBaseDir: String, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appName: String, sparkConf: SparkConf) = - this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) + def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") - private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_") - .toLowerCase + "-" + System.currentTimeMillis - val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") - + val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId) + val logDirName: String = logDir.split("/").last protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS)) // For testing. Keep track of all JSON serialized events that have been logged. private[scheduler] val loggedEvents = new ArrayBuffer[JValue] - /** - * Return only the unique application directory without the base directory. - */ - def getApplicationLogDir(): String = { - name - } - /** * Begin logging events. * If compression is used, log a file that indicates which compression library is used. @@ -184,6 +175,18 @@ private[spark] object EventLoggingListener extends Logging { } else "" } + /** + * Return a file-system-safe path to the log directory for the given application. + * + * @param logBaseDir A base directory for the path to the log directory for given application. + * @param appId A unique app ID. + * @return A path which consists of file-system-safe characters. + */ + def getLogDirPath(logBaseDir: String, appId: String): String = { + val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase + Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") + } + /** * Parse the event logging information associated with the logs in the given directory. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index a0be8307eff27..992c477493d8e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -23,6 +23,8 @@ package org.apache.spark.scheduler * machines become available and can launch tasks on them. */ private[spark] trait SchedulerBackend { + private val appId = "spark-application-" + System.currentTimeMillis + def start(): Unit def stop(): Unit def reviveOffers(): Unit @@ -33,10 +35,10 @@ private[spark] trait SchedulerBackend { def isReady(): Boolean = true /** - * The application ID associated with the job, if any. + * Get an application ID associated with the job. * - * @return The application ID, or None if the backend does not provide an ID. + * @return An application ID */ - def applicationId(): Option[String] = None + def applicationId(): String = appId } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1c1ce666eab0f..a129a434c9a1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -31,6 +31,8 @@ import org.apache.spark.storage.BlockManagerId */ private[spark] trait TaskScheduler { + private val appId = "spark-application-" + System.currentTimeMillis + def rootPool: Pool def schedulingMode: SchedulingMode @@ -66,10 +68,10 @@ private[spark] trait TaskScheduler { blockManagerId: BlockManagerId): Boolean /** - * The application ID associated with the job, if any. + * Get an application ID associated with the job. * - * @return The application ID, or None if the backend does not provide an ID. + * @return An application ID */ - def applicationId(): Option[String] = None + def applicationId(): String = appId } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 633e892554c50..4dc550413c13c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -492,7 +492,7 @@ private[spark] class TaskSchedulerImpl( } } - override def applicationId(): Option[String] = backend.applicationId() + override def applicationId(): String = backend.applicationId() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5c5ecc8434d78..ed209d195ec9d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -68,9 +68,8 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val eventLogDir = sc.eventLogger.map(_.logDir) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, eventLogDir) + appUIAddress, sc.eventLogDir) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() @@ -129,7 +128,11 @@ private[spark] class SparkDeploySchedulerBackend( totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } - override def applicationId(): Option[String] = Option(appId) + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } private def waitForRegistration() = { registrationLock.synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3161f1ee9fa8a..90828578cd88f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -76,6 +76,8 @@ private[spark] class CoarseMesosSchedulerBackend( var nextMesosTaskId = 0 + @volatile var appId: String = _ + def newMesosTaskId(): Int = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -167,7 +169,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - logInfo("Registered as framework ID " + frameworkId.getValue) + appId = frameworkId.getValue + logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() @@ -313,4 +316,10 @@ private[spark] class CoarseMesosSchedulerBackend( slaveLost(d, s) } + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 4c49aa074ebc0..b11786368e661 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils /** @@ -62,6 +62,8 @@ private[spark] class MesosSchedulerBackend( var classLoader: ClassLoader = null + @volatile var appId: String = _ + override def start() { synchronized { classLoader = Thread.currentThread.getContextClassLoader @@ -177,7 +179,8 @@ private[spark] class MesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { val oldClassLoader = setClassLoader() try { - logInfo("Registered as framework ID " + frameworkId.getValue) + appId = frameworkId.getValue + logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() @@ -372,4 +375,10 @@ private[spark] class MesosSchedulerBackend( // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 9ea25c2bc7090..58b78f041cd85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -88,6 +88,7 @@ private[spark] class LocalActor( private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) extends SchedulerBackend with ExecutorBackend { + private val appId = "local-" + System.currentTimeMillis var localActor: ActorRef = null override def start() { @@ -115,4 +116,6 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! StatusUpdate(taskId, state, serializedData) } + override def applicationId(): String = appId + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 49fea6d9e2a76..8569c6f3cbbc3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry} import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) +private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source { override val metricRegistry = new MetricRegistry() - override val sourceName = "%s.BlockManager".format(sc.appName) + override val sourceName = "BlockManager" metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index e42b181194727..3925f0ccbdbf0 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.metrics -import org.apache.spark.metrics.source.Source import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource +import org.apache.spark.metrics.source.Source -import scala.collection.mutable.ArrayBuffer +import com.codahale.metrics.MetricRegistry +import scala.collection.mutable.ArrayBuffer class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ @@ -39,6 +40,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod test("MetricsSystem with default config") { val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) + metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) @@ -49,6 +51,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod test("MetricsSystem with sources add") { val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) + metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) @@ -60,4 +63,125 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod metricsSystem.registerSource(source) assert(metricsSystem.invokePrivate(sources()).length === 1) } + + test("MetricsSystem with Driver instance") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "driver" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === s"$appId.$executorId.${source.sourceName}") + } + + test("MetricsSystem with Driver instance and spark.app.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val executorId = "driver" + conf.set("spark.executor.id", executorId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Driver instance and spark.executor.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + conf.set("spark.app.id", appId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Executor instance") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "executor.1" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === s"$appId.$executorId.${source.sourceName}") + } + + test("MetricsSystem with Executor instance and spark.app.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val executorId = "executor.1" + conf.set("spark.executor.id", executorId) + + val instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Executor instance and spark.executor.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + conf.set("spark.app.id", appId) + + val instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with instance which is neither Driver nor Executor") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "dummyExecutorId" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "testInstance" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + + // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. + assert(metricName != s"$appId.$executorId.${source.sourceName}") + assert(metricName === source.sourceName) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index e5315bc93e217..3efa85431876b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -169,7 +169,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { // Verify logging directory exists val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) eventLogger.start() val logPath = new Path(eventLogger.logDir) assert(fileSystem.exists(logPath)) @@ -209,7 +211,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { // Verify that all information is correctly parsed before stop() val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) eventLogger.start() var eventLoggingInfo = EventLoggingListener.parseLoggingInfo(eventLogger.logDir, fileSystem) assertInfoCorrect(eventLoggingInfo, loggerStopped = false) @@ -228,7 +232,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { */ private def testEventLogging(compressionCodec: Option[String] = None) { val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -408,4 +414,6 @@ object EventLoggingListenerSuite { } conf } + + def getUniqueApplicationId = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7ab351d1b4d24..48114feee6233 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -155,7 +155,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * This child listener inherits only the event buffering functionality, but does not actually * log the events. */ - private class EventMonster(conf: SparkConf) extends EventLoggingListener("test", conf) { + private class EventMonster(conf: SparkConf) + extends EventLoggingListener("test", "testdir", conf) { logger.close() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index 99c8d13231aac..eb6e88cf5520d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.nio.ByteBuffer +import java.util.concurrent.Semaphore import scala.collection.mutable.ArrayBuffer @@ -36,6 +37,7 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { val receiver = new FakeReceiver val executor = new FakeReceiverSupervisor(receiver) + val executorStarted = new Semaphore(0) assert(executor.isAllEmpty) @@ -43,6 +45,7 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { val executingThread = new Thread() { override def run() { executor.start() + executorStarted.release(1) executor.awaitTermination() } } @@ -57,6 +60,9 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { } } + // Ensure executor is started + executorStarted.acquire() + // Verify that receiver was started assert(receiver.onStartCalled) assert(executor.isReceiverStarted) @@ -186,10 +192,10 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. */ class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - var otherThread: Thread = null - var receiving = false - var onStartCalled = false - var onStopCalled = false + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false def onStart() { otherThread = new Thread() { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 10cbeb8b94325..229b7a09f456b 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -47,6 +47,7 @@ class ExecutorRunnable( hostname: String, executorMemory: Int, executorCores: Int, + appAttemptId: String, securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { @@ -83,7 +84,7 @@ class ExecutorRunnable( ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, - localResources) + appAttemptId, localResources) logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index d7a7175d5e578..5cb4753de2e84 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -43,6 +43,7 @@ trait ExecutorRunnableUtil extends Logging { hostname: String, executorMemory: Int, executorCores: Int, + appId: String, localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM val javaOpts = ListBuffer[String]() @@ -114,6 +115,7 @@ trait ExecutorRunnableUtil extends Logging { slaveId.toString, hostname.toString, executorCores.toString, + appId, "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4f4f1d2aaaade..e1af8d5a74cb1 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -57,6 +57,7 @@ object AllocationType extends Enumeration { private[yarn] abstract class YarnAllocator( conf: Configuration, sparkConf: SparkConf, + appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) @@ -295,6 +296,7 @@ private[yarn] abstract class YarnAllocator( executorHostname, executorMemory, executorCores, + appAttemptId.getApplicationId.toString, securityMgr) launcherPool.execute(executorRunnable) } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 200a30899290b..6bb4b82316ad4 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -155,6 +155,10 @@ private[spark] class YarnClientSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } - override def applicationId(): Option[String] = Option(appId).map(_.toString()) + override def applicationId(): String = + Option(appId).map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 39436d0999663..3a186cfeb4eeb 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -48,6 +48,13 @@ private[spark] class YarnClusterSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } - override def applicationId(): Option[String] = sc.getConf.getOption("spark.yarn.app.id") + override def applicationId(): String = + // In YARN Cluster mode, spark.yarn.app.id is expect to be set + // before user application is launched. + // So, if spark.yarn.app.id is not set, it is something wrong. + sc.getConf.getOption("spark.yarn.app.id").getOrElse { + logError("Application ID is not set.") + super.applicationId + } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 833be12982e71..0b5a92d87d722 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -47,6 +47,7 @@ class ExecutorRunnable( hostname: String, executorMemory: Int, executorCores: Int, + appId: String, securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { @@ -80,7 +81,7 @@ class ExecutorRunnable( ctx.setTokens(ByteBuffer.wrap(dob.getData())) val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, - localResources) + appId, localResources) logInfo(s"Setting up executor with environment: $env") logInfo("Setting up executor with commands: " + commands) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index e44a8db41b97e..2bbf5d7db8668 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -41,7 +41,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { override protected def releaseContainer(container: Container) = { amClient.releaseAssignedContainer(container.getId()) From cf1d32e3e1071829b152d4b597bf0a0d7a5629a2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 3 Oct 2014 14:22:11 -0700 Subject: [PATCH 190/315] [SPARK-1860] More conservative app directory cleanup. First contribution to the project, so apologize for any significant errors. This PR addresses [SPARK-1860]. The application directories are now cleaned up in a more conservative manner. Previously, app-* directories were cleaned up if the directory's timestamp was older than a given time. However, the timestamp on a directory does not reflect the modification times of the files in that directory. Therefore, app-* directories were wiped out even if the files inside them were created recently and possibly being used by Executor tasks. The solution is to change the cleanup logic to inspect all files within the app-* directory and only eliminate the app-* directory if all files in the directory are stale. Author: mcheah Closes #2609 from mccheah/worker-better-app-dir-cleanup and squashes the following commits: 87b5d03 [mcheah] [SPARK-1860] Using more string interpolation. Better error logging. 802473e [mcheah] [SPARK-1860] Cleaning up the logs generated when cleaning directories. e0a1f2e [mcheah] [SPARK-1860] Fixing broken unit test. 77a9de0 [mcheah] [SPARK-1860] More conservative app directory cleanup. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 +--- .../apache/spark/deploy/worker/Worker.scala | 37 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 21 +++++++---- .../org/apache/spark/util/UtilsSuite.scala | 23 +++++++++--- 4 files changed, 62 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 00a43673e5cd3..71650cd773bcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -42,7 +42,7 @@ private[spark] class ExecutorRunner( val workerId: String, val host: String, val sparkHome: File, - val workDir: File, + val executorDir: File, val workerUrl: String, val conf: SparkConf, var state: ExecutorState.Value) @@ -130,12 +130,6 @@ private[spark] class ExecutorRunner( */ def fetchAndRunExecutor() { try { - // Create the executor's working directory - val executorDir = new File(workDir, appId + "/" + execId) - if (!executorDir.mkdirs()) { - throw new IOException("Failed to create directory " + executorDir) - } - // Launch the process val command = getCommandSeq logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0c454e4138c96..3b13f43a1868c 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -18,15 +18,18 @@ package org.apache.spark.deploy.worker import java.io.File +import java.io.IOException import java.text.SimpleDateFormat import java.util.Date +import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import scala.language.postfixOps import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.commons.io.FileUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} @@ -191,6 +194,7 @@ private[spark] class Worker( changeMaster(masterUrl, masterWebUiUrl) context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) if (CLEANUP_ENABLED) { + logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) } @@ -201,10 +205,23 @@ private[spark] class Worker( case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor val cleanupFuture = concurrent.future { - logInfo("Cleaning up oldest application directories in " + workDir + " ...") - Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS) - .foreach(Utils.deleteRecursively) + val appDirs = workDir.listFiles() + if (appDirs == null) { + throw new IOException("ERROR: Failed to list files in " + appDirs) + } + appDirs.filter { dir => + // the directory is used by an application - check that the application is not running + // when cleaning up + val appIdFromDir = dir.getName + val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + dir.isDirectory && !isAppStillRunning && + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + }.foreach { dir => + logInfo(s"Removing directory: ${dir.getPath}") + Utils.deleteRecursively(dir) + } } + cleanupFuture onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) @@ -233,8 +250,15 @@ private[spark] class Worker( } else { try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + + // Create the executor's working directory + val executorDir = new File(workDir, appId + "/" + execId) + if (!executorDir.mkdirs()) { + throw new IOException("Failed to create directory " + executorDir) + } + val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.LOADING) + self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -242,12 +266,13 @@ private[spark] class Worker( master ! ExecutorStateChanged(appId, execId, manager.state, None, None) } catch { case e: Exception => { - logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) + master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None) } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9399ddab76331..a67124140f9da 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,6 +35,8 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.log4j.PropertyConfigurator @@ -705,17 +707,20 @@ private[spark] object Utils extends Logging { } /** - * Finds all the files in a directory whose last modified time is older than cutoff seconds. - * @param dir must be the path to a directory, or IllegalArgumentException is thrown - * @param cutoff measured in seconds. Files older than this are returned. + * Determines if a directory contains any files newer than cutoff seconds. + * + * @param dir must be the path to a directory, or IllegalArgumentException is thrown + * @param cutoff measured in seconds. Returns true if there are any files in dir newer than this. */ - def findOldFiles(dir: File, cutoff: Long): Seq[File] = { + def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { val currentTimeMillis = System.currentTimeMillis - if (dir.isDirectory) { - val files = listFilesSafely(dir) - files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) } + if (!dir.isDirectory) { + throw new IllegalArgumentException (dir + " is not a directory!") } else { - throw new IllegalArgumentException(dir + " is not a directory!") + val files = FileUtils.listFilesAndDirs(dir, TrueFileFilter.TRUE, TrueFileFilter.TRUE) + val cutoffTimeInMillis = (currentTimeMillis - (cutoff * 1000)) + val newFiles = files.filter { _.lastModified > cutoffTimeInMillis } + newFiles.nonEmpty } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 70d423ba8a04d..e63d9d085e385 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -189,17 +189,28 @@ class UtilsSuite extends FunSuite { assert(Utils.getIteratorSize(iterator) === 5L) } - test("findOldFiles") { + test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories val child2: File = Utils.createTempDir(parent.getCanonicalPath) - // set the last modified time of child1 to 10 secs old - child1.setLastModified(System.currentTimeMillis() - (1000 * 10)) + val child3: File = Utils.createTempDir(child1.getCanonicalPath) + // set the last modified time of child1 to 30 secs old + child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) - val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs - assert(result.size.equals(1)) - assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath)) + // although child1 is old, child2 is still new so return true + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child2.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + parent.setLastModified(System.currentTimeMillis - (1000 * 30)) + // although parent and its immediate children are new, child3 is still old + // we expect a full recursive search for new files. + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child3.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) } test("resolveURI") { From 32fad4233f353814496c84e15ba64326730b7ae7 Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Sun, 5 Oct 2014 09:49:24 -0700 Subject: [PATCH 191/315] [SPARK-3597][Mesos] Implement `killTask`. The MesosSchedulerBackend did not previously implement `killTask`, resulting in an exception. Author: Brenden Matthews Closes #2453 from brndnmtthws/implement-killtask and squashes the following commits: 23ddcdc [Brenden Matthews] [SPARK-3597][Mesos] Implement `killTask`. --- .../scheduler/cluster/mesos/MesosSchedulerBackend.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index b11786368e661..e0f2fd622f54c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -372,6 +372,13 @@ private[spark] class MesosSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status)) } + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + driver.killTask( + TaskID.newBuilder() + .setValue(taskId.toString).build() + ) + } + // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) From a7c73130f1b6b0b8b19a7b0a0de5c713b673cd7b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 5 Oct 2014 09:55:17 -0700 Subject: [PATCH 192/315] SPARK-1656: Fix potential resource leaks JIRA: https://issues.apache.org/jira/browse/SPARK-1656 Author: zsxwing Closes #577 from zsxwing/SPARK-1656 and squashes the following commits: c431095 [zsxwing] Add a comment and fix the code style 2de96e5 [zsxwing] Make sure file will be deleted if exception happens 28b90dc [zsxwing] Update to follow the code style 4521d6e [zsxwing] Merge branch 'master' into SPARK-1656 afc3383 [zsxwing] Update to follow the code style 071fdd1 [zsxwing] SPARK-1656: Fix potential resource leaks --- .../spark/broadcast/HttpBroadcast.scala | 25 +++++++++++-------- .../master/FileSystemPersistenceEngine.scala | 14 ++++++++--- .../org/apache/spark/storage/DiskStore.scala | 16 +++++++++++- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 942dc7d7eac87..4cd4f4f96fd16 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(new FileOutputStream(file)) - } else { - new BufferedOutputStream(new FileOutputStream(file), bufferSize) + val fileOutputStream = new FileOutputStream(file) + try { + val out: OutputStream = { + if (compress) { + compressionCodec.compressedOutputStream(fileOutputStream) + } else { + new BufferedOutputStream(fileOutputStream, bufferSize) + } } + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject(value) + serOut.close() + files += file + } finally { + fileOutputStream.close() } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() - files += file } private def read[T: ClassTag](id: Long): T = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index aa85aa060d9c1..08a99bbe68578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -83,15 +83,21 @@ private[spark] class FileSystemPersistenceEngine( val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - out.write(serialized) - out.close() + try { + out.write(serialized) + } finally { + out.close() + } } def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) - dis.readFully(fileData) - dis.close() + try { + dis.readFully(fileData) + } finally { + dis.close() + } val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index e9304f6bb45d0..bac459e835a3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) - blockManager.dataSerializeStream(blockId, outputStream, values) + try { + try { + blockManager.dataSerializeStream(blockId, outputStream, values) + } finally { + // Close outputStream here because it should be closed before file is deleted. + outputStream.close() + } + } catch { + case e: Throwable => + if (file.exists()) { + file.delete() + } + throw e + } + val length = file.length val timeTaken = System.currentTimeMillis - startTime From 1b97a941a09a2f63d442f435c1b444d857cd6956 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 5 Oct 2014 11:19:17 -0700 Subject: [PATCH 193/315] [SPARK-3007][SQL] Fixes dynamic partitioning support for lower Hadoop versions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a follow up of #2226 and #2616 to fix Jenkins master SBT build failures for lower Hadoop versions (1.0.x and 2.0.x). The root cause is the semantics difference of `FileSystem.globStatus()` between different versions of Hadoop, as illustrated by the following test code: ```scala object GlobExperiments extends App { val conf = new Configuration() val fs = FileSystem.getLocal(conf) fs.globStatus(new Path("/tmp/wh/*/*/*")).foreach { status => println(status.getPath) } } ``` Target directory structure: ``` /tmp/wh ├── dir0 │   ├── dir1 │   │   └── level2 │   └── level1 └── level0 ``` Hadoop 2.4.1 result: ``` file:/tmp/wh/dir0/dir1/level2 ``` Hadoop 1.0.4 resuet: ``` file:/tmp/wh/dir0/dir1/level2 file:/tmp/wh/dir0/level1 file:/tmp/wh/level0 ``` In #2226 and #2616, we call `FileOutputCommitter.commitJob()` at the end of the job, and the `_SUCCESS` mark file is written. When working with lower Hadoop versions, due to the `globStatus()` semantics issue, `_SUCCESS` is included as a separate partition data file by `Hive.loadDynamicPartitions()`, and fails partition spec checking. The fix introduced in this PR is kind of a hack: when inserting data with dynamic partitioning, we intentionally avoid writing the `_SUCCESS` marker to workaround this issue. Hive doesn't suffer this issue because `FileSinkOperator` doesn't call `FileOutputCommitter.commitJob()`, instead, it calls `Utilities.mvFileToFinalPath()` to cleanup the output directory and then loads it into Hive warehouse by with `loadDynamicPartitions()`/`loadPartition()`/`loadTable()`. This approach is better because it handles failed job and speculative tasks properly. We should add this step to `InsertIntoHiveTable` in another PR. Author: Cheng Lian Closes #2663 from liancheng/dp-hadoop-1-fix and squashes the following commits: 0177dae [Cheng Lian] Fixes dynamic partitioning support for lower Hadoop versions --- .../spark/sql/hive/hiveWriterContainers.scala | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ac5c7a8220296..6ccbc22a4acfb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -55,8 +55,8 @@ private[hive] class SparkHiveWriterContainer( private var taID: SerializableWritable[TaskAttemptID] = null @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private lazy val committer = conf.value.getOutputCommitter - @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient protected lazy val committer = conf.value.getOutputCommitter + @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] @@ -122,8 +122,6 @@ private[hive] class SparkHiveWriterContainer( } } - // ********* Private Functions ********* - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { jobID = jobId splitID = splitId @@ -157,12 +155,18 @@ private[hive] object SparkHiveWriterContainer { } } +private[spark] object SparkHiveDynamicPartitionWriterContainer { + val SUCCESSFUL_JOB_OUTPUT_DIR_MARKER = "mapreduce.fileoutputcommitter.marksuccessfuljobs" +} + private[spark] class SparkHiveDynamicPartitionWriterContainer( @transient jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + import SparkHiveDynamicPartitionWriterContainer._ + private val defaultPartName = jobConf.get( ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) @@ -179,6 +183,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( commit() } + override def commitJob(): Unit = { + // This is a hack to avoid writing _SUCCESS mark file. In lower versions of Hadoop (e.g. 1.0.4), + // semantics of FileSystem.globStatus() is different from higher versions (e.g. 2.4.1) and will + // include _SUCCESS file when glob'ing for dynamic partition data files. + // + // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: + // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then + // load it with loadDynamicPartitions/loadPartition/loadTable. + val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + super.commitJob() + jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + } + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { val dynamicPartPath = dynamicPartColNames .zip(row.takeRight(dynamicPartColNames.length)) From e222221e24c122300bbde6d5ec4002a7c42b2e24 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 5 Oct 2014 13:22:40 -0700 Subject: [PATCH 194/315] HOTFIX: Fix unicode error in merge script. The merge script builds up a big command array and sometimes this contains both unicode and ascii strings. This doesn't work if you try to join them into a single string. Longer term a solution is to go and make sure the source of all strings is unicode. This patch provides a simpler solution... just print the array rather than joining. I actually prefer printing an array here anyways since joining on spaces is lossy in the case of arguments that themselves contain spaces. Author: Patrick Wendell Closes #2645 from pwendell/merge-script and squashes the following commits: 167b792 [Patrick Wendell] HOTFIX: Fix unicode error in merge script. --- dev/merge_spark_pr.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index a8e92e36fe0d8..02ac20984add9 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -73,11 +73,10 @@ def fail(msg): def run_cmd(cmd): + print cmd if isinstance(cmd, list): - print " ".join(cmd) return subprocess.check_output(cmd) else: - print cmd return subprocess.check_output(cmd.split(" ")) From 79b2108de30bf91c8e58bb36405d334aeb2a00ad Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 5 Oct 2014 17:44:38 -0700 Subject: [PATCH 195/315] [Minor] Trivial fix to make codes more readable It should just use `maxResults` there. Author: Liang-Chi Hsieh Closes #2654 from viirya/trivial_fix and squashes the following commits: 1362289 [Liang-Chi Hsieh] Trivial fix to make codes more readable. --- .../src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8bcc098bbb620..fad3b39f81413 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -268,7 +268,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 - val results = runHive(sql, 100000) + val results = runHive(sql, maxResults) // It is very confusing when you only get back some of the results... if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") results From 58f5361caaa2f898e38ae4b3794167881e20a818 Mon Sep 17 00:00:00 2001 From: scwf Date: Sun, 5 Oct 2014 17:47:20 -0700 Subject: [PATCH 196/315] [SPARK-3792][SQL] Enable JavaHiveQLSuite Do not use TestSQLContext in JavaHiveQLSuite, that may lead to two SparkContexts in one jvm and enable JavaHiveQLSuite Author: scwf Closes #2652 from scwf/fix-JavaHiveQLSuite and squashes the following commits: be35c91 [scwf] enable JavaHiveQLSuite --- .../sql/hive/api/java/JavaHiveQLSuite.scala | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 9644b707eb1a0..46b11b582b26d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -25,34 +25,30 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.TestSQLContext // Implicits import scala.collection.JavaConversions._ class JavaHiveQLSuite extends FunSuite { - lazy val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + lazy val javaCtx = new JavaSparkContext(TestHive.sparkContext) // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM lazy val javaHiveCtx = new JavaHiveContext(javaCtx) { override val sqlContext = TestHive } - ignore("SELECT * FROM src") { + test("SELECT * FROM src") { assert( javaHiveCtx.sql("SELECT * FROM src").collect().map(_.getInt(0)) === TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) } - private val explainCommandClassName = - classOf[ExplainCommand].getSimpleName.stripSuffix("$") - def isExplanation(result: JavaSchemaRDD) = { val explanation = result.collect().map(_.getString(0)) - explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith("== Physical Plan ==") } - ignore("Query Hive native command execution result") { + test("Query Hive native command execution result") { val tableName = "test_native_commands" assertResult(0) { @@ -63,23 +59,18 @@ class JavaHiveQLSuite extends FunSuite { javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.sql("SHOW TABLES").registerTempTable("show_tables") - assert( javaHiveCtx - .sql("SELECT result FROM show_tables") + .sql("SHOW TABLES") .collect() .map(_.getString(0)) .contains(tableName)) - assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.sql(s"DESCRIBE $tableName").registerTempTable("describe_table") - - + assertResult(Array(Array("key", "int"), Array("value", "string"))) { javaHiveCtx - .sql("SELECT result FROM describe_table") + .sql(s"describe $tableName") .collect() - .map(_.getString(0).split("\t").map(_.trim)) + .map(row => Array(row.get(0).asInstanceOf[String], row.get(1).asInstanceOf[String])) .toArray } @@ -89,7 +80,7 @@ class JavaHiveQLSuite extends FunSuite { TestHive.reset() } - ignore("Exactly once semantics for DDL and command statements") { + test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" val q0 = javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") From 34b97a067d1b370fbed8ecafab2f48501a35d783 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 5 Oct 2014 17:51:59 -0700 Subject: [PATCH 197/315] [SPARK-3645][SQL] Makes table caching eager by default and adds syntax for lazy caching Although lazy caching for in-memory table seems consistent with the `RDD.cache()` API, it's relatively confusing for users who mainly work with SQL and not familiar with Spark internals. The `CACHE TABLE t; SELECT COUNT(*) FROM t;` pattern is also commonly seen just to ensure predictable performance. This PR makes both the `CACHE TABLE t [AS SELECT ...]` statement and the `SQLContext.cacheTable()` API eager by default, and adds a new `CACHE LAZY TABLE t [AS SELECT ...]` syntax to provide lazy in-memory table caching. Also, took the chance to make some refactoring: `CacheCommand` and `CacheTableAsSelectCommand` are now merged and renamed to `CacheTableCommand` since the former is strictly a special case of the latter. A new `UncacheTableCommand` is added for the `UNCACHE TABLE t` statement. Author: Cheng Lian Closes #2513 from liancheng/eager-caching and squashes the following commits: fe92287 [Cheng Lian] Makes table caching eager by default and adds syntax for lazy caching --- .../apache/spark/sql/catalyst/SqlParser.scala | 45 +++--- .../spark/sql/catalyst/analysis/Catalog.scala | 2 +- .../sql/catalyst/plans/logical/commands.scala | 15 +- .../org/apache/spark/sql/CacheManager.scala | 9 +- .../columnar/InMemoryColumnarTableScan.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../apache/spark/sql/execution/commands.scala | 47 +++--- .../apache/spark/sql/CachedTableSuite.scala | 145 +++++++++++++----- .../spark/sql/hive/ExtendedHiveQlParser.scala | 66 ++++---- .../org/apache/spark/sql/hive/TestHive.scala | 6 +- .../spark/sql/hive/CachedTableSuite.scala | 78 ++++++++-- 11 files changed, 265 insertions(+), 158 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 26336332c05a2..854b5b461bdc8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -67,11 +67,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") + protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") @@ -80,9 +81,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") - protected val LAST = Keyword("LAST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -91,42 +92,42 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") protected val INSERT = Keyword("INSERT") + protected val INTERSECT = Keyword("INTERSECT") protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") + protected val LAST = Keyword("LAST") + protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") + protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") + protected val LOWER = Keyword("LOWER") protected val MAX = Keyword("MAX") protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") protected val OR = Keyword("OR") - protected val OVERWRITE = Keyword("OVERWRITE") - protected val LIKE = Keyword("LIKE") - protected val RLIKE = Keyword("RLIKE") - protected val UPPER = Keyword("UPPER") - protected val LOWER = Keyword("LOWER") - protected val REGEXP = Keyword("REGEXP") protected val ORDER = Keyword("ORDER") protected val OUTER = Keyword("OUTER") + protected val OVERWRITE = Keyword("OVERWRITE") + protected val REGEXP = Keyword("REGEXP") protected val RIGHT = Keyword("RIGHT") + protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") + protected val SQRT = Keyword("SQRT") protected val STRING = Keyword("STRING") + protected val SUBSTR = Keyword("SUBSTR") + protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") + protected val UPPER = Keyword("UPPER") protected val WHERE = Keyword("WHERE") - protected val INTERSECT = Keyword("INTERSECT") - protected val EXCEPT = Keyword("EXCEPT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SQRT = Keyword("SQRT") - protected val ABS = Keyword("ABS") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -183,17 +184,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } protected lazy val cache: Parser[LogicalPlan] = - CACHE ~ TABLE ~> ident ~ opt(AS ~> select) <~ opt(";") ^^ { - case tableName ~ None => - CacheCommand(tableName, true) - case tableName ~ Some(plan) => - CacheTableAsSelectCommand(tableName, plan) + CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan, isLazy.isDefined) } - + protected lazy val unCache: Parser[LogicalPlan] = UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => CacheCommand(tableName, false) - } + case tableName => UncacheTableCommand(tableName) + } protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") @@ -283,7 +282,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { + termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 616f1e2ecb60f..2059a91ba0612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -87,7 +87,7 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tableName: String, alias: Option[String] = None): LogicalPlan = { val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName")) + val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName")) val tableWithQualifiers = Subquery(tblName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 8366639fa0e8b..9a3848cfc6b62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -56,9 +56,15 @@ case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends } /** - * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. + * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command. */ -case class CacheCommand(tableName: String, doCache: Boolean) extends Command +case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean) + extends Command + +/** + * Returned for the "UNCACHE TABLE tableName" command. + */ +case class UncacheTableCommand(tableName: String) extends Command /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -75,8 +81,3 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } - -/** - * Returned for the "CACHE TABLE tableName AS SELECT .." command. - */ -case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index aebdbb68e49b8..3bf7382ac67a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -91,14 +91,10 @@ private[sql] trait CacheManager { } /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock { + private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache)) - - if (dataIndex < 0) { - throw new IllegalArgumentException(s"Table $query is not cached.") - } - + require(dataIndex >= 0, s"Table $query is not cached.") cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) cachedData.remove(dataIndex) } @@ -135,5 +131,4 @@ private[sql] trait CacheManager { case _ => } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index cec82a7f2df94..4f79173a26f88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -111,7 +111,7 @@ private[sql] case class InMemoryRelation( override def newInstance() = { new InMemoryRelation( - output.map(_.newInstance), + output.map(_.newInstance()), useCompression, batchSize, storageLevel, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cf93d5ad7b503..5c16d0c624128 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -304,10 +304,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) - case logical.CacheCommand(tableName, cache) => - Seq(execution.CacheCommand(tableName, cache)(context)) - case logical.CacheTableAsSelectCommand(tableName, plan) => - Seq(execution.CacheTableAsSelectCommand(tableName, plan)) + case logical.CacheTableCommand(tableName, optPlan, isLazy) => + Seq(execution.CacheTableCommand(tableName, optPlan, isLazy)) + case logical.UncacheTableCommand(tableName) => + Seq(execution.UncacheTableCommand(tableName)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index f88099ec0761e..d49633c24ad4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -138,49 +138,54 @@ case class ExplainCommand( * :: DeveloperApi :: */ @DeveloperApi -case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) +case class CacheTableCommand( + tableName: String, + plan: Option[LogicalPlan], + isLazy: Boolean) extends LeafNode with Command { override protected lazy val sideEffectResult = { - if (doCache) { - context.cacheTable(tableName) - } else { - context.uncacheTable(tableName) + import sqlContext._ + + plan.foreach(_.registerTempTable(tableName)) + val schemaRDD = table(tableName) + schemaRDD.cache() + + if (!isLazy) { + // Performs eager caching + schemaRDD.count() } + Seq.empty[Row] } override def output: Seq[Attribute] = Seq.empty } + /** * :: DeveloperApi :: */ @DeveloperApi -case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( - @transient context: SQLContext) - extends LeafNode with Command { - +case class UncacheTableCommand(tableName: String) extends LeafNode with Command { override protected lazy val sideEffectResult: Seq[Row] = { - Row("# Registered as a temporary table", null, null) +: - child.output.map(field => Row(field.name, field.dataType.toString, null)) + sqlContext.table(tableName).unpersist() + Seq.empty[Row] } + + override def output: Seq[Attribute] = Seq.empty } /** * :: DeveloperApi :: */ @DeveloperApi -case class CacheTableAsSelectCommand(tableName: String, logicalPlan: LogicalPlan) +case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) extends LeafNode with Command { - - override protected[sql] lazy val sideEffectResult = { - import sqlContext._ - logicalPlan.registerTempTable(tableName) - cacheTable(tableName) - Seq.empty[Row] - } - override def output: Seq[Attribute] = Seq.empty - + override protected lazy val sideEffectResult: Seq[Row] = { + Row("# Registered as a temporary table", null, null) +: + child.output.map(field => Row(field.name, field.dataType.toString, null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 957388e99bd85..1e624f97004f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,30 +18,39 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.storage.RDDBlockId case class BigData(s: String) class CachedTableSuite extends QueryTest { - import TestSQLContext._ TestData // Load test tables. - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. - */ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached } - if (cachedData.size != numCachedTables) { - fail( - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) - } + } + + def rddIdOf(tableName: String): Int = { + val executedPlan = table(tableName).queryExecution.executedPlan + executedPlan.collect { + case InMemoryColumnarTableScan(_, _, relation) => + relation.cachedColumnBuffers.id + case _ => + fail(s"Table $tableName is not cached\n" + executedPlan) + }.head + } + + def isMaterialized(rddId: Int): Boolean = { + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("too big for memory") { @@ -52,10 +61,33 @@ class CachedTableSuite extends QueryTest { uncacheTable("bigData") } - test("calling .cache() should use inmemory columnar caching") { + test("calling .cache() should use in-memory columnar caching") { table("testData").cache() + assertCached(table("testData")) + } + + test("calling .unpersist() should drop in-memory columnar cache") { + table("testData").cache() + table("testData").count() + table("testData").unpersist(true) + assertCached(table("testData"), 0) + } + + test("isCached") { + cacheTable("testData") assertCached(table("testData")) + assert(table("testData").queryExecution.withCachedData match { + case _: InMemoryRelation => true + case _ => false + }) + + uncacheTable("testData") + assert(!isCached("testData")) + assert(table("testData").queryExecution.withCachedData match { + case _: InMemoryRelation => false + case _ => true + }) } test("SPARK-1669: cacheTable should be idempotent") { @@ -64,32 +96,27 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") assertCached(table("testData")) - cacheTable("testData") - table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => - fail("cacheTable is not idempotent") + assertResult(1, "InMemoryRelation not found, testData should have been cached") { + table("testData").queryExecution.withCachedData.collect { + case r: InMemoryRelation => r + }.size + } - case _ => + cacheTable("testData") + assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { + table("testData").queryExecution.withCachedData.collect { + case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => r + }.size } } test("read from cached table and uncache") { cacheTable("testData") - - checkAnswer( - table("testData"), - testData.collect().toSeq - ) - + checkAnswer(table("testData"), testData.collect().toSeq) assertCached(table("testData")) uncacheTable("testData") - - checkAnswer( - table("testData"), - testData.collect().toSeq - ) - + checkAnswer(table("testData"), testData.collect().toSeq) assertCached(table("testData"), 0) } @@ -99,10 +126,12 @@ class CachedTableSuite extends QueryTest { } } - test("SELECT Star Cached Table") { + test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") cacheTable("selectStar") - sql("SELECT * FROM selectStar WHERE key = 1").collect() + checkAnswer( + sql("SELECT * FROM selectStar WHERE key = 1"), + Seq(Row(1, "1"))) uncacheTable("selectStar") } @@ -120,23 +149,57 @@ class CachedTableSuite extends QueryTest { sql("CACHE TABLE testData") assertCached(table("testData")) - assert(isCached("testData"), "Table 'testData' should be cached") + val rddId = rddIdOf("testData") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assertCached(table("testData"), 0) assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } - - test("CACHE TABLE tableName AS SELECT Star Table") { + + test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - sql("SELECT * FROM testCacheTable WHERE key = 1").collect() - assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } - - test("'CACHE TABLE tableName AS SELECT ..'") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + + test("CACHE TABLE tableName AS SELECT ...") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } + + test("CACHE LAZY TABLE tableName") { + sql("CACHE LAZY TABLE testData") + assertCached(table("testData")) + + val rddId = rddIdOf("testData") + assert( + !isMaterialized(rddId), + "Lazily cached in-memory table shouldn't be materialized eagerly") + + sql("SELECT COUNT(*) FROM testData").collect() + assert( + isMaterialized(rddId), + "Lazily cached in-memory table should have been materialized") + + uncacheTable("testData") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index e7e1cb980c2ae..c5844e92eaaa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific + * A parser that recognizes all HiveQL constructs together with several Spark SQL specific * extensions like CACHE TABLE and UNCACHE TABLE. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - +private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { + def apply(input: String): LogicalPlan = { // Special-case out set commands since the value fields can be // complex to handle without RegexParsers. Also this approach @@ -54,16 +54,17 @@ private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with Packr protected case class Keyword(str: String) - protected val CACHE = Keyword("CACHE") - protected val SET = Keyword("SET") protected val ADD = Keyword("ADD") - protected val JAR = Keyword("JAR") - protected val TABLE = Keyword("TABLE") protected val AS = Keyword("AS") - protected val UNCACHE = Keyword("UNCACHE") - protected val FILE = Keyword("FILE") + protected val CACHE = Keyword("CACHE") protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") protected val SOURCE = Keyword("SOURCE") + protected val TABLE = Keyword("TABLE") + protected val UNCACHE = Keyword("UNCACHE") protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -79,57 +80,56 @@ private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with Packr override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = + protected lazy val query: Parser[LogicalPlan] = cache | uncache | addJar | addFile | dfs | source | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = - remainingQuery ^^ { - case r => HiveQl.createPlan(r.trim()) + restInput ^^ { + case statement => HiveQl.createPlan(statement.trim()) } - /** It returns all remaining query */ - protected lazy val remainingQuery: Parser[String] = new Parser[String] { + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) + Success(in.source.toString, in.drop(in.source.length())) } - /** It returns all query */ - protected lazy val allQuery: Parser[String] = new Parser[String] { + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) + Success( + in.source.subSequence(in.offset, in.source.length).toString, + in.drop(in.source.length())) } protected lazy val cache: Parser[LogicalPlan] = - CACHE ~ TABLE ~> ident ~ opt(AS ~> hiveQl) ^^ { - case tableName ~ None => CacheCommand(tableName, true) - case tableName ~ Some(plan) => - CacheTableAsSelectCommand(tableName, plan) + CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan, isLazy.isDefined) } protected lazy val uncache: Parser[LogicalPlan] = UNCACHE ~ TABLE ~> ident ^^ { - case tableName => CacheCommand(tableName, false) + case tableName => UncacheTableCommand(tableName) } protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> remainingQuery ^^ { - case rq => AddJar(rq.trim()) + ADD ~ JAR ~> restInput ^^ { + case jar => AddJar(jar.trim()) } protected lazy val addFile: Parser[LogicalPlan] = - ADD ~ FILE ~> remainingQuery ^^ { - case rq => AddFile(rq.trim()) + ADD ~ FILE ~> restInput ^^ { + case file => AddFile(file.trim()) } protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> allQuery ^^ { - case aq => NativeCommand(aq.trim()) + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim()) } protected lazy val source: Parser[LogicalPlan] = - SOURCE ~> remainingQuery ^^ { - case rq => SourceCommand(rq.trim()) + SOURCE ~> restInput ^^ { + case file => SourceCommand(file.trim()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c0e69393cc2e3..a4354c1379c63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} +import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.SQLConf @@ -67,7 +67,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath /** Sets up the system initially or after a RESET command */ - protected def configure() { + protected def configure(): Unit = { setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") setConf("hive.metastore.warehouse.dir", warehousePath) @@ -154,7 +154,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override lazy val analyzed = { val describedTables = logical match { case NativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheCommand(tbl, _) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil case _ => Nil } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 158cfb5bbee7c..2060e1f1a7a4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{QueryTest, SchemaRDD} -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { - import TestHive._ - /** * Throws a test failed exception when the number of cached tables differs from the expected * number. @@ -34,11 +34,24 @@ class CachedTableSuite extends QueryTest { case cached: InMemoryRelation => cached } - if (cachedData.size != numCachedTables) { - fail( - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + + def rddIdOf(tableName: String): Int = { + val executedPlan = table(tableName).queryExecution.executedPlan + executedPlan.collect { + case InMemoryColumnarTableScan(_, _, relation) => + relation.cachedColumnBuffers.id + case _ => + fail(s"Table $tableName is not cached\n" + executedPlan) + }.head + } + + def isMaterialized(rddId: Int): Boolean = { + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache table") { @@ -102,16 +115,47 @@ class CachedTableSuite extends QueryTest { assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } - test("CACHE TABLE AS SELECT") { - assertCached(sql("SELECT * FROM src"), 0) - sql("CACHE TABLE test AS SELECT key FROM src") + test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { + sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assertCached(table("testCacheTable")) - checkAnswer( - sql("SELECT * FROM test"), - sql("SELECT key FROM src").collect().toSeq) + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") - assertCached(sql("SELECT * FROM test")) + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } + + test("CACHE TABLE tableName AS SELECT ...") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } - assertCached(sql("SELECT * FROM test JOIN test"), 2) + test("CACHE LAZY TABLE tableName") { + sql("CACHE LAZY TABLE src") + assertCached(table("src")) + + val rddId = rddIdOf("src") + assert( + !isMaterialized(rddId), + "Lazily cached in-memory table shouldn't be materialized eagerly") + + sql("SELECT COUNT(*) FROM src").collect() + assert( + isMaterialized(rddId), + "Lazily cached in-memory table should have been materialized") + + uncacheTable("src") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } } From 90897ea5f24b03c9f3455a62c7f68b3d3f0435ad Mon Sep 17 00:00:00 2001 From: Renat Yusupov Date: Sun, 5 Oct 2014 17:56:24 -0700 Subject: [PATCH 198/315] [SPARK-3776][SQL] Wrong conversion to Catalyst for Option[Product] Author: Renat Yusupov Closes #2641 from r3natko/feature/catalyst_option and squashes the following commits: 55d0c06 [Renat Yusupov] [SQL] SPARK-3776: Wrong conversion to Catalyst for Option[Product] --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/ScalaReflectionSuite.scala | 21 ++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 88a8fa7c28e0f..b3ae8e6779700 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -33,7 +33,7 @@ object ScalaReflection { /** Converts Scala objects to catalyst rows / types */ def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull + case o: Option[_] => o.map(convertToCatalyst).orNull case s: Seq[_] => s.map(convertToCatalyst) case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 428607d8c8253..488e373854bb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -53,7 +53,8 @@ case class OptionalData( floatField: Option[Float], shortField: Option[Short], byteField: Option[Byte], - booleanField: Option[Boolean]) + booleanField: Option[Boolean], + structField: Option[PrimitiveData]) case class ComplexData( arrayField: Seq[Int], @@ -100,7 +101,7 @@ class ScalaReflectionSuite extends FunSuite { nullable = true)) } - test("optinal data") { + test("optional data") { val schema = schemaFor[OptionalData] assert(schema === Schema( StructType(Seq( @@ -110,7 +111,8 @@ class ScalaReflectionSuite extends FunSuite { StructField("floatField", FloatType, nullable = true), StructField("shortField", ShortType, nullable = true), StructField("byteField", ByteType, nullable = true), - StructField("booleanField", BooleanType, nullable = true))), + StructField("booleanField", BooleanType, nullable = true), + StructField("structField", schemaFor[PrimitiveData].dataType, nullable = true))), nullable = true)) } @@ -228,4 +230,17 @@ class ScalaReflectionSuite extends FunSuite { assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) } + + test("convert PrimitiveData to catalyst") { + val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + assert(convertToCatalyst(data) === convertedData) + } + + test("convert Option[Product] to catalyst") { + val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) + val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) + assert(convertToCatalyst(data) === convertedData) + } } From 8d22dbb5ec7a0727afdfebbbc2c57ffdb384dd0b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 5 Oct 2014 18:44:12 -0700 Subject: [PATCH 199/315] SPARK-3794 [CORE] Building spark core fails due to inadvertent dependency on Commons IO Remove references to Commons IO FileUtils and replace with pure Java version, which doesn't need to traverse the whole directory tree first. I think this method could be refined further if it would be alright to rename it and its args and break it down into two methods. I'm starting with a simple recursive rendition. Author: Sean Owen Closes #2662 from srowen/SPARK-3794 and squashes the following commits: 4cd172f [Sean Owen] Remove references to Commons IO FileUtils and replace with pure Java version, which doesn't need to traverse the whole directory tree first --- .../apache/spark/deploy/worker/Worker.scala | 1 - .../scala/org/apache/spark/util/Utils.scala | 20 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3b13f43a1868c..9b52cb06fb6fa 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -29,7 +29,6 @@ import scala.language.postfixOps import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.commons.io.FileUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a67124140f9da..3d307b3c16d3e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,8 +35,6 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.log4j.PropertyConfigurator @@ -710,18 +708,20 @@ private[spark] object Utils extends Logging { * Determines if a directory contains any files newer than cutoff seconds. * * @param dir must be the path to a directory, or IllegalArgumentException is thrown - * @param cutoff measured in seconds. Returns true if there are any files in dir newer than this. + * @param cutoff measured in seconds. Returns true if there are any files or directories in the + * given directory whose last modified time is later than this many seconds ago */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { - val currentTimeMillis = System.currentTimeMillis if (!dir.isDirectory) { - throw new IllegalArgumentException (dir + " is not a directory!") - } else { - val files = FileUtils.listFilesAndDirs(dir, TrueFileFilter.TRUE, TrueFileFilter.TRUE) - val cutoffTimeInMillis = (currentTimeMillis - (cutoff * 1000)) - val newFiles = files.filter { _.lastModified > cutoffTimeInMillis } - newFiles.nonEmpty + throw new IllegalArgumentException("$dir is not a directory!") } + val filesAndDirs = dir.listFiles() + val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) + + filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) || + filesAndDirs.filter(_.isDirectory).exists( + subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff) + ) } /** From fd7b15539669b14996a51610d6724ca0811f9d65 Mon Sep 17 00:00:00 2001 From: Nathan Kronenfeld Date: Sun, 5 Oct 2014 21:03:48 -0700 Subject: [PATCH 200/315] Rectify gereneric parameter names between SparkContext and AccumulablePa... AccumulableParam gave its generic parameters as 'R, T', whereas SparkContext labeled them 'T, R'. Trivial, but really confusing. I resolved this in favor of AccumulableParam, because it seemed to have some logic for its names. I also extended this minimal, but at least present, justification into the SparkContext comments. Author: Nathan Kronenfeld Closes #2637 from nkronenfeld/accumulators and squashes the following commits: 98d6b74 [Nathan Kronenfeld] Rectify gereneric parameter names between SparkContext and AccumulableParam --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 97109b9f41b60..396cdd1247e07 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -779,20 +779,20 @@ class SparkContext(config: SparkConf) extends Logging { /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param) /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can * access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param, Some(name)) /** From c9ae79fba25cd49ca70ca398bc75434202d26a97 Mon Sep 17 00:00:00 2001 From: scwf Date: Sun, 5 Oct 2014 21:36:20 -0700 Subject: [PATCH 201/315] [SPARK-3765][Doc] Add test information to sbt build docs Add testing with sbt to doc ```building-spark.md``` Author: scwf Closes #2629 from scwf/sbt-doc and squashes the following commits: fd9cf29 [scwf] add testing with sbt to docs --- docs/building-spark.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/building-spark.md b/docs/building-spark.md index 901c157162fee..b2940ee4029e8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -171,6 +171,21 @@ can be set to control the SBT build. For example: sbt/sbt -Pyarn -Phadoop-2.3 assembly +# Testing with SBT + +Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test + +To run only a specific test suite as follows: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" + +To run test suites of a specific sub project as follows: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test + # Speeding up Compilation with Zinc [Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental From 20ea54cc7a5176ebc63bfa9393a9bf84619bfc66 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 6 Oct 2014 14:05:45 -0700 Subject: [PATCH 202/315] [SPARK-2461] [PySpark] Add a toString method to GeneralizedLinearModel Add a toString method to GeneralizedLinearModel, also change `__str__` to `__repr__` for some classes, to provide better message in repr. This PR is based on #1388, thanks to sryza! closes #1388 Author: Sandy Ryza Author: Davies Liu Closes #2625 from davies/string and squashes the following commits: 3544aad [Davies Liu] fix LinearModel 0bcd642 [Davies Liu] Merge branch 'sandy-spark-2461' of github.com:sryza/spark 1ce5c2d [Sandy Ryza] __repr__ back to __str__ in a couple places aa9e962 [Sandy Ryza] Switch __str__ to __repr__ a0c5041 [Sandy Ryza] Add labels back in 1aa17f5 [Sandy Ryza] Match existing conventions fac1bc4 [Sandy Ryza] Fix PEP8 error f7b58ed [Sandy Ryza] SPARK-2461. Add a toString method to GeneralizedLinearModel --- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 2 ++ python/pyspark/mllib/regression.py | 3 +++ python/pyspark/serializers.py | 6 +++--- python/pyspark/sql.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index d0fe4179685ca..00dfc86c9e0bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } + + override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) } /** diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cbdbc09858013..8fe8c6db2ad9c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -66,6 +66,9 @@ def weights(self): def intercept(self): return self._intercept + def __repr__(self): + return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + class LinearRegressionModelBase(LinearModel): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2672da36c1f50..099fa54cf2bd7 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -211,7 +211,7 @@ def __eq__(self, other): return (isinstance(other, BatchedSerializer) and other.serializer == self.serializer) - def __str__(self): + def __repr__(self): return "BatchedSerializer<%s>" % str(self.serializer) @@ -279,7 +279,7 @@ def __eq__(self, other): return (isinstance(other, CartesianDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "CartesianDeserializer<%s, %s>" % \ (str(self.key_ser), str(self.val_ser)) @@ -306,7 +306,7 @@ def __eq__(self, other): return (isinstance(other, PairDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 974b5e287bc00..114644ab8b79d 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -201,7 +201,7 @@ def __init__(self, elementType, containsNull=True): self.elementType = elementType self.containsNull = containsNull - def __str__(self): + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) From 4f01265f7d62e070ba42c251255e385644c1b16c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 6 Oct 2014 14:07:53 -0700 Subject: [PATCH 203/315] [SPARK-3786] [PySpark] speedup tests This patch try to speed up tests of PySpark, re-use the SparkContext in tests.py and mllib/tests.py to reduce the overhead of create SparkContext, remove some test cases, which did not make sense. It also improve the performance of some cases, such as MergerTests and SortTests. before this patch: real 21m27.320s user 4m42.967s sys 0m17.343s after this patch: real 9m47.541s user 2m12.947s sys 0m14.543s It almost cut the time by half. Author: Davies Liu Closes #2646 from davies/tests and squashes the following commits: c54de60 [Davies Liu] revert change about memory limit 6a2a4b0 [Davies Liu] refactor of tests, speedup 100% --- python/pyspark/mllib/tests.py | 2 +- python/pyspark/shuffle.py | 5 +- python/pyspark/tests.py | 92 ++++++++++++++++------------------- python/run-tests | 74 ++++++++++++++-------------- 4 files changed, 82 insertions(+), 91 deletions(-) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f72e88ba6e2ba..5c20e100e144f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -32,7 +32,7 @@ from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint -from pyspark.tests import PySparkTestCase +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index ce597cbe91e15..d57a802e4734a 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -396,7 +396,6 @@ def _external_items(self): for v in self.data.iteritems(): yield v self.data.clear() - gc.collect() # remove the merged partition for j in range(self.spills): @@ -428,7 +427,7 @@ def _recursive_merged_items(self, start): subdirs = [os.path.join(d, "parts", str(i)) for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions, self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -486,7 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch = 10 + batch = 100 chunks, current_chunk = [], [] iterator = iter(iterator) while True: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6fb6bc998c752..7f05d48ade2b3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -67,10 +67,10 @@ SPARK_HOME = os.environ["SPARK_HOME"] -class TestMerger(unittest.TestCase): +class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 16 + self.N = 1 << 14 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -115,7 +115,7 @@ def test_medium_dataset(self): sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), @@ -123,7 +123,7 @@ def test_huge_dataset(self): m._cleanup() -class TestSorter(unittest.TestCase): +class SorterTests(unittest.TestCase): def test_in_memory_sort(self): l = range(1024) random.shuffle(l) @@ -244,16 +244,25 @@ def tearDown(self): sys.path = self._old_sys_path -class TestCheckpoint(PySparkTestCase): +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class CheckpointTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.checkpointDir.name) self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -288,7 +297,7 @@ def test_checkpoint_and_restore(self): self.assertEquals([1, 2, 3, 4], recovered.collect()) -class TestAddFile(PySparkTestCase): +class AddFileTests(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -354,7 +363,7 @@ def func(x): self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) -class TestRDDFunctions(PySparkTestCase): +class RDDTests(ReusedPySparkTestCase): def test_id(self): rdd = self.sc.parallelize(range(10)) @@ -365,12 +374,6 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.sc.stop() - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - self.sc = SparkContext("local") - def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -636,7 +639,7 @@ def test_distinct(self): self.assertEquals(result.count(), 3) -class TestProfiler(PySparkTestCase): +class ProfilerTests(PySparkTestCase): def setUp(self): self._old_sys_path = list(sys.path) @@ -666,10 +669,9 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) -class TestSQL(PySparkTestCase): +class SQLTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.sqlCtx = SQLContext(self.sc) def test_udf(self): @@ -754,27 +756,19 @@ def test_serialize_nested_array_and_map(self): self.assertEqual("2", row.d) -class TestIO(PySparkTestCase): - - def test_stdout_redirection(self): - import subprocess - - def func(x): - subprocess.check_call('ls', shell=True) - self.sc.parallelize([1]).foreach(func) +class InputFormatTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) -class TestInputFormat(PySparkTestCase): - - def setUp(self): - PySparkTestCase.setUp(self) - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc) - - def tearDown(self): - PySparkTestCase.tearDown(self) - shutil.rmtree(self.tempdir.name) + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) def test_sequencefiles(self): basepath = self.tempdir.name @@ -954,15 +948,13 @@ def test_converters(self): self.assertEqual(maps, em) -class TestOutputFormat(PySparkTestCase): +class OutputFormatTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.tempdir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.tempdir.name, ignore_errors=True) def test_sequencefiles(self): @@ -1243,8 +1235,7 @@ def test_malformed_RDD(self): basepath + "/malformed/sequence")) -class TestDaemon(unittest.TestCase): - +class DaemonTests(unittest.TestCase): def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -1290,7 +1281,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class TestWorker(PySparkTestCase): +class WorkerTests(PySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) @@ -1342,11 +1333,6 @@ def run(): rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) - def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default - rdd = self.sc.parallelize(range(N), N) - self.assertEquals(N, rdd.count()) - def test_after_exception(self): def raise_exception(_): raise Exception() @@ -1379,7 +1365,7 @@ def test_accumulator_when_reuse_worker(self): self.assertEqual(sum(range(100)), acc1.value) -class TestSparkSubmit(unittest.TestCase): +class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() @@ -1492,6 +1478,8 @@ def test_single_script_on_cluster(self): |sc = SparkContext() |print sc.parallelize([1, 2, 3]).map(foo).collect() """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf proc = subprocess.Popen( [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) @@ -1500,7 +1488,11 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) -class ContextStopTests(unittest.TestCase): +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) def test_stop(self): sc = SparkContext() diff --git a/python/run-tests b/python/run-tests index a7ec270c7da21..c713861eb77bb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -34,7 +34,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -48,6 +48,37 @@ function run_test() { fi } +function run_core_tests() { + echo "Run core tests ..." + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py" + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +} + +function run_sql_tests() { + echo "Run sql tests ..." + run_test "pyspark/sql.py" +} + +function run_mllib_tests() { + echo "Run mllib tests ..." + run_test "pyspark/mllib/classification.py" + run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/linalg.py" + run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/recommendation.py" + run_test "pyspark/mllib/regression.py" + run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/tree.py" + run_test "pyspark/mllib/util.py" + run_test "pyspark/mllib/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -60,29 +91,9 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_test "pyspark/rdd.py" -run_test "pyspark/context.py" -run_test "pyspark/conf.py" -run_test "pyspark/sql.py" -# These tests are included in the module-level docs, and so must -# be handled on a higher level rather than within the python file. -export PYSPARK_DOC_TEST=1 -run_test "pyspark/broadcast.py" -run_test "pyspark/accumulators.py" -run_test "pyspark/serializers.py" -unset PYSPARK_DOC_TEST -run_test "pyspark/shuffle.py" -run_test "pyspark/tests.py" -run_test "pyspark/mllib/classification.py" -run_test "pyspark/mllib/clustering.py" -run_test "pyspark/mllib/linalg.py" -run_test "pyspark/mllib/random.py" -run_test "pyspark/mllib/recommendation.py" -run_test "pyspark/mllib/regression.py" -run_test "pyspark/mllib/stat.py" -run_test "pyspark/mllib/tests.py" -run_test "pyspark/mllib/tree.py" -run_test "pyspark/mllib/util.py" +run_core_tests +run_sql_tests +run_mllib_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -90,19 +101,8 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - run_test "pyspark/sql.py" - # These tests are included in the module-level docs, and so must - # be handled on a higher level rather than within the python file. - export PYSPARK_DOC_TEST=1 - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - unset PYSPARK_DOC_TEST - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_core_tests + run_sql_tests fi if [[ $FAILED == 0 ]]; then From 2300eb58ae79a86e65b3ff608a578f5d4c09892b Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Mon, 6 Oct 2014 14:08:40 -0700 Subject: [PATCH 204/315] [SPARK-3773][PySpark][Doc] Sphinx build warning When building Sphinx documents for PySpark, we have 12 warnings. Their causes are almost docstrings in broken ReST format. To reproduce this issue, we should run following commands on the commit: 6e27cb630de69fa5acb510b4e2f6b980742b1957. ```bash $ cd ./python/docs $ make clean html ... /Users//MyRepos/Scala/spark/python/pyspark/__init__.py:docstring of pyspark.SparkContext.sequenceFile:4: ERROR: Unexpected indentation. /Users//MyRepos/Scala/spark/python/pyspark/__init__.py:docstring of pyspark.RDD.saveAsSequenceFile:4: ERROR: Unexpected indentation. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.LogisticRegressionWithSGD.train:14: ERROR: Unexpected indentation. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.LogisticRegressionWithSGD.train:16: WARNING: Definition list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.LogisticRegressionWithSGD.train:17: WARNING: Block quote ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.SVMWithSGD.train:14: ERROR: Unexpected indentation. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.SVMWithSGD.train:16: WARNING: Definition list ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/pyspark/mllib/classification.py:docstring of pyspark.mllib.classification.SVMWithSGD.train:17: WARNING: Block quote ends without a blank line; unexpected unindent. /Users//MyRepos/Scala/spark/python/docs/pyspark.mllib.rst:50: WARNING: missing attribute mentioned in :members: or __all__: module pyspark.mllib.regression, attribute RidgeRegressionModelLinearRegressionWithSGD /Users//MyRepos/Scala/spark/python/pyspark/mllib/tree.py:docstring of pyspark.mllib.tree.DecisionTreeModel.predict:3: ERROR: Unexpected indentation. ... checking consistency... /Users//MyRepos/Scala/spark/python/docs/modules.rst:: WARNING: document isn't included in any toctree ... copying static files... WARNING: html_static_path entry u'/Users//MyRepos/Scala/spark/python/docs/_static' does not exist ... build succeeded, 12 warnings. ``` Author: cocoatomo Closes #2653 from cocoatomo/issues/3773-sphinx-build-warnings and squashes the following commits: 6f65661 [cocoatomo] [SPARK-3773][PySpark][Doc] Sphinx build warning --- python/docs/modules.rst | 7 ------- python/pyspark/context.py | 1 + python/pyspark/mllib/classification.py | 26 ++++++++++++++++---------- python/pyspark/mllib/regression.py | 15 +++++++++------ python/pyspark/mllib/tree.py | 1 + python/pyspark/rdd.py | 1 + 6 files changed, 28 insertions(+), 23 deletions(-) delete mode 100644 python/docs/modules.rst diff --git a/python/docs/modules.rst b/python/docs/modules.rst deleted file mode 100644 index 183564659fbcf..0000000000000 --- a/python/docs/modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -. -= - -.. toctree:: - :maxdepth: 4 - - pyspark diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e9418320ff781..a45d79d6424c7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -410,6 +410,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. The mechanism is as follows: + 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key and value Writable classes 2. Serialization is attempted via Pyrolite pickling diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ac142fb49a90c..a765b1c4f7d87 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -89,11 +89,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param regParam: The regularizer parameter (default: 1.0). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater + - "none" for no regularizer + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features @@ -158,11 +161,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, @param initialWeights: The initial weights (default: None). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 8fe8c6db2ad9c..54f34a98337ca 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -22,7 +22,7 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -155,11 +155,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, @param regParam: The regularizer parameter (default: 1.0). @param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater, + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index afdcdbdf3ae01..5d7abfb96b7fe 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -48,6 +48,7 @@ def __del__(self): def predict(self, x): """ Predict the label of one or more examples. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..e77669aad76b6 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1208,6 +1208,7 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file system, using the L{org.apache.hadoop.io.Writable} types that we convert from the RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. From 69c3f441a9b6e942d6c08afecd59a0349d61cc7b Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 6 Oct 2014 14:19:06 -0700 Subject: [PATCH 205/315] [SPARK-3479] [Build] Report failed test category This PR allows SparkQA (i.e. Jenkins) to report in its posts to GitHub what category of test failed, if one can be determined. The failure categories are: * general failure * RAT checks failed * Scala style checks failed * Python style checks failed * Build failed * Spark unit tests failed * PySpark unit tests failed * MiMa checks failed This PR also fixes the diffing logic used to determine if a patch introduces new classes. Author: Nicholas Chammas Closes #2606 from nchammas/report-failed-test-category and squashes the following commits: d67df03 [Nicholas Chammas] report what test category failed --- dev/run-tests | 32 ++++++++++++- dev/run-tests-codes.sh | 27 +++++++++++ dev/run-tests-jenkins | 102 ++++++++++++++++++++++++++++------------- 3 files changed, 126 insertions(+), 35 deletions(-) create mode 100644 dev/run-tests-codes.sh diff --git a/dev/run-tests b/dev/run-tests index c3d8f49cdd993..4be2baaf48cd1 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -24,6 +24,16 @@ cd "$FWDIR" # Remove work directory rm -rf ./work +source "$FWDIR/dev/run-tests-codes.sh" + +CURRENT_BLOCK=$BLOCK_GENERAL + +function handle_error () { + echo "[error] Got a return code of $? on line $1 of the run-tests script." + exit $CURRENT_BLOCK +} + + # Build against the right verison of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then fi fi -# Fail fast -set -e set -o pipefail +trap 'handle_error $LINENO' ERR echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_RAT + ./dev/check-license echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SCALA_STYLE + ./dev/lint-scala echo "" echo "=========================================================================" echo "Running Python style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYTHON_STYLE + ./dev/lint-python echo "" @@ -118,6 +136,8 @@ echo "=========================================================================" echo "Building Spark" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_BUILD + { # We always build with Hive because the PySpark Spark SQL tests need it. BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" @@ -141,6 +161,8 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS + { # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. @@ -175,10 +197,16 @@ echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS + ./python/run-tests echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + ./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh new file mode 100644 index 0000000000000..1348e0609dda4 --- /dev/null +++ b/dev/run-tests-codes.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +readonly BLOCK_GENERAL=10 +readonly BLOCK_RAT=11 +readonly BLOCK_SCALA_STYLE=12 +readonly BLOCK_PYTHON_STYLE=13 +readonly BLOCK_BUILD=14 +readonly BLOCK_SPARK_UNIT_TESTS=15 +readonly BLOCK_PYSPARK_UNIT_TESTS=16 +readonly BLOCK_MIMA=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 0b1e31b9413cf..451f3b771cc76 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,9 +26,23 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +source "$FWDIR/dev/run-tests-codes.sh" + COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" +# Important Environment Variables +# --- +# $ghprbActualCommit +#+ This is the hash of the most recent commit in the PR. +#+ The merge-base of this and master is the commit from which the PR was branched. +# $sha1 +#+ If the patch merges cleanly, this is a reference to the merge commit hash +#+ (e.g. "origin/pr/2606/merge"). +#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. +#+ The merge-base of this and master in the case of a clean merge is the most recent commit +#+ against master. + COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" @@ -84,42 +98,46 @@ function post_message () { fi } + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + # check PR merge-ability and check for new public classes { if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not** merge cleanly!" + merge_note=" * This patch **does not merge cleanly**." else merge_note=" * This patch merges cleanly." + fi + + source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " + ) + new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) - source_files=$( - git diff master... --name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master... ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ "$new_public_classes" == "" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi + if [ -z "$new_public_classes" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" fi } @@ -147,12 +165,30 @@ function post_message () { post_message "$fail_message" exit $test_result + elif [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes all tests**." else - if [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes** unit tests." + if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then + failing_test="some tests" + elif [ "$test_result" -eq "$BLOCK_RAT" ]; then + failing_test="RAT tests" + elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then + failing_test="Scala style tests" + elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then + failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then + failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then + failing_test="Spark unit tests" + elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then + failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" else - test_result_note=" * This patch **fails** unit tests." + failing_test="some tests" fi + + test_result_note=" * This patch **fails $failing_test**." fi } From 70e824f750aa8ed446eec104ba158b0503ba58a9 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 7 Oct 2014 09:51:37 -0500 Subject: [PATCH 206/315] [SPARK-3627] - [yarn] - fix exit code and final status reporting to RM See the description and whats handled in the jira comment: https://issues.apache.org/jira/browse/SPARK-3627?focusedCommentId=14150013&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-14150013 This does not handle yarn client mode reporting of the driver to the AM. I think that should be handled when we make it an unmanaged AM. Author: Thomas Graves Closes #2577 from tgravescs/SPARK-3627 and squashes the following commits: 9c2efbf [Thomas Graves] review comments e8cc261 [Thomas Graves] fix accidental typo during fixing comment 24c98e3 [Thomas Graves] rework 85f1901 [Thomas Graves] Merge remote-tracking branch 'upstream/master' into SPARK-3627 fab166d [Thomas Graves] update based on review comments 32f4dfa [Thomas Graves] switch back f0b6519 [Thomas Graves] change order of cleanup staging dir d3cc800 [Thomas Graves] SPARK-3627 - yarn - fix exit code and final status reporting to RM --- .../spark/deploy/yarn/YarnRMClientImpl.scala | 26 +- .../spark/deploy/yarn/ApplicationMaster.scala | 295 +++++++++++------- .../spark/deploy/yarn/YarnRMClient.scala | 4 +- .../spark/deploy/yarn/YarnRMClientImpl.scala | 13 +- 4 files changed, 212 insertions(+), 126 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 9bd1719cb1808..7faf55bc63372 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -40,6 +40,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var rpc: YarnRPC = null private var resourceManager: AMRMProtocol = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -51,8 +52,11 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.rpc = YarnRPC.create(conf) this.uiHistoryAddress = uiHistoryAddress - resourceManager = registerWithResourceManager(conf) - registerApplicationMaster(uiAddress) + synchronized { + resourceManager = registerWithResourceManager(conf) + registerApplicationMaster(uiAddress) + registered = true + } new YarnAllocationHandler(conf, sparkConf, resourceManager, getAttemptId(), args, preferredNodeLocations, securityMgr) @@ -66,14 +70,16 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appAttemptId } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = { - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(getAttemptId()) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(uiHistoryAddress) - resourceManager.finishApplicationMaster(finishReq) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(getAttemptId()) + finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) + finishReq.setTrackingUrl(uiHistoryAddress) + resourceManager.finishApplicationMaster(finishReq) + } } override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index caceef5d4b5b0..a3c43b43848d2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -56,8 +57,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + @volatile private var exitCode = 0 + @volatile private var unregistered = false @volatile private var finished = false @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + @volatile private var finalMsg: String = "" @volatile private var userClassThread: Thread = _ private var reporterThread: Thread = _ @@ -71,80 +75,107 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val sparkContextRef = new AtomicReference[SparkContext](null) final def run(): Int = { - val appAttemptId = client.getAttemptId() + try { + val appAttemptId = client.getAttemptId() - if (isDriver) { - // Set the web ui port to be ephemeral for yarn so we don't conflict with - // other spark processes running on the same box - System.setProperty("spark.ui.port", "0") + if (isDriver) { + // Set the web ui port to be ephemeral for yarn so we don't conflict with + // other spark processes running on the same box + System.setProperty("spark.ui.port", "0") - // Set the master property to match the requested mode. - System.setProperty("spark.master", "yarn-cluster") + // Set the master property to match the requested mode. + System.setProperty("spark.master", "yarn-cluster") - // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. - System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - } + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + } - logInfo("ApplicationAttemptId: " + appAttemptId) + logInfo("ApplicationAttemptId: " + appAttemptId) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - finish(FinalApplicationStatus.SUCCEEDED) - } + val cleanupHook = new Runnable { + override def run() { + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // this shouldn't ever happen, but if it does assume weird failure + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "shutdown hook called without cleanly finishing") + } - // Cleanup the staging dir after the app is finished, or if it's the last attempt at - // running the AM. - val maxAppAttempts = client.getMaxRegAttempts(yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - if (finished || isLastAttempt) { - cleanupStagingDir() + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir() + } + } } } - } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) + // Use higher priority than FileSystem. + assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) + ShutdownHookManager + .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserClass which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) + // Call this to force generation of secret so it gets populated into the + // Hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. + val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { - runDriver(securityMgr) - } else { - runExecutorLauncher(securityMgr) + if (isDriver) { + runDriver(securityMgr) + } else { + runExecutorLauncher(securityMgr) + } + } catch { + case e: Exception => + // catch everything else if not specifically handled + logError("Uncaught exception: ", e) + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "Uncaught exception: " + e.getMessage()) } + exitCode + } - if (finalStatus != FinalApplicationStatus.UNDEFINED) { - finish(finalStatus) - 0 - } else { - 1 + /** + * unregister is used to completely unregister the application from the ResourceManager. + * This means the ResourceManager will not retry the application attempt on your behalf if + * a failure occurred. + */ + final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) } } - final def finish(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { if (!finished) { - logInfo(s"Finishing ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - finished = true + logInfo(s"Final app status: ${status}, exitCode: ${code}" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code finalStatus = status - try { - if (Thread.currentThread() != reporterThread) { - reporterThread.interrupt() - reporterThread.join() - } - } finally { - client.shutdown(status, Option(diagnostics).getOrElse("")) + finalMsg = msg + finished = true + if (Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() } } } @@ -182,7 +213,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() - val userThread = startUserClass() + setupSystemSecurityManager() + userClassThread = startUserClass() // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. @@ -190,15 +222,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // If there is no SparkContext at this point, just fail the app. if (sc == null) { - finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SC_NOT_INITED, + "Timed out waiting for SparkContext.") } else { registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) - try { - userThread.join() - } finally { - // In cluster mode, ask the reporter thread to stop since the user app is finished. - reporterThread.interrupt() - } + userClassThread.join() } } @@ -211,7 +240,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // In client mode the actor will stop the reporter thread. reporterThread.join() - finalStatus = FinalApplicationStatus.SUCCEEDED } private def launchReporterThread(): Thread = { @@ -231,33 +259,26 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val t = new Thread { override def run() { var failureCount = 0 - while (!finished) { try { - checkNumExecutorsFailed() - if (!finished) { + if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Max number of executor failures reached") + } else { logDebug("Sending progress") allocator.allocateResources() } failureCount = 0 } catch { + case i: InterruptedException => case e: Throwable => { failureCount += 1 if (!NonFatal(e) || failureCount >= reporterMaxFailures) { - logError("Exception was thrown from Reporter thread.", e) - finish(FinalApplicationStatus.FAILED, "Exception was thrown" + - s"${failureCount} time(s) from Reporter thread.") - - /** - * If exception is thrown from ReporterThread, - * interrupt user class to stop. - * Without this interrupting, if exception is - * thrown before allocating enough executors, - * YarnClusterScheduler waits until timeout even though - * we cannot allocate executors. - */ - logInfo("Interrupting user class to stop.") - userClassThread.interrupt + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + + s"${failureCount} time(s) from Reporter thread.") + } else { logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) } @@ -308,7 +329,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkContextRef.synchronized { var count = 0 val waitTime = 10000L - val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) + val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10) while (sparkContextRef.get() == null && count < numTries && !finished) { logInfo("Waiting for spark context initialization ... " + count) count = count + 1 @@ -328,10 +349,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def waitForSparkDriver(): ActorRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false + var count = 0 val hostport = args.userArgs(0) val (driverHost, driverPort) = Utils.parseHostPort(hostport) - while (!driverUp) { + + // spark driver should already be up since it launched us, but we don't want to + // wait forever, so wait 100 seconds max to match the cluster mode setting. + // Leave this config unpublished for now. SPARK-3779 to investigating changing + // this config to be time based. + val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 1000) + + while (!driverUp && !finished && count < numTries) { try { + count = count + 1 val socket = new Socket(driverHost, driverPort) socket.close() logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) @@ -343,6 +373,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, Thread.sleep(100) } } + + if (!driverUp) { + throw new SparkException("Failed to connect to driver!") + } + sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) @@ -354,18 +389,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") } - private def checkNumExecutorsFailed() = { - if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finish(FinalApplicationStatus.FAILED, "Max number of executor failures reached.") - - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from checkNumExecutorsFailed") - sc.stop() - } - } - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter() = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) @@ -379,40 +402,81 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } + /** + * This system security manager applies to the entire process. + * It's main purpose is to handle the case if the user code does a System.exit. + * This allows us to catch that and properly set the YARN application status and + * cleanup if needed. + */ + private def setupSystemSecurityManager(): Unit = { + try { + var stopped = false + System.setSecurityManager(new java.lang.SecurityManager() { + override def checkExit(paramInt: Int) { + if (!stopped) { + logInfo("In securityManager checkExit, exit code: " + paramInt) + if (paramInt == 0) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } else { + finish(FinalApplicationStatus.FAILED, + paramInt, + "User class exited with non-zero exit code") + } + stopped = true + } + } + // required for the checkExit to work properly + override def checkPermission(perm: java.security.Permission): Unit = {} + }) + } + catch { + case e: SecurityException => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SECURITY, + "Error in setSecurityManager") + logError("Error in setSecurityManager:", e) + } + } + + /** + * Start the user class, which contains the spark driver, in a separate Thread. + * If the main routine exits cleanly or exits with System.exit(0) we + * assume it was successful, for all other cases we assume failure. + * + * Returns the user thread that was started. + */ private def startUserClass(): Thread = { logInfo("Starting the user JAR in a separate Thread") System.setProperty("spark.executor.instances", args.numExecutors.toString) val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - userClassThread = new Thread { + val userThread = new Thread { override def run() { - var status = FinalApplicationStatus.FAILED try { - // Copy val mainArgs = new Array[String](args.userArgs.size) args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) - // Some apps have "System.exit(0)" at the end. The user thread will stop here unless - // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. - status = FinalApplicationStatus.SUCCEEDED + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running users class") } catch { case e: InvocationTargetException => e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - - case e => throw e + case e: Exception => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, + "User class threw exception: " + e.getMessage) + // re-throw to get it logged + throw e } - } finally { - logDebug("Finishing main") - finalStatus = status } } } - userClassThread.setName("Driver") - userClassThread.start() - userClassThread + userThread.setName("Driver") + userThread.start() + userThread } // Actor used to monitor the driver when running in client deploy mode. @@ -432,7 +496,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, override def receive = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver ! x @@ -446,6 +510,15 @@ object ApplicationMaster extends Logging { val SHUTDOWN_HOOK_PRIORITY: Int = 30 + // exit codes for different causes, no reason behind the values + private val EXIT_SUCCESS = 0 + private val EXIT_UNCAUGHT_EXCEPTION = 10 + private val EXIT_MAX_EXECUTOR_FAILURES = 11 + private val EXIT_REPORTER_FAILURE = 12 + private val EXIT_SC_NOT_INITED = 13 + private val EXIT_SECURITY = 14 + private val EXIT_EXCEPTION_USER_CLASS = 15 + private var master: ApplicationMaster = _ def main(args: Array[String]) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 943dc56202a37..2510b9c9cef68 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -49,12 +49,12 @@ trait YarnRMClient { securityMgr: SecurityManager): YarnAllocator /** - * Shuts down the AM. Guaranteed to only be called once. + * Unregister the AM. Guaranteed to only be called once. * * @param status The final status of the AM. * @param diagnostics Diagnostics message to include in the final status. */ - def shutdown(status: FinalApplicationStatus, diagnostics: String = ""): Unit + def unregister(status: FinalApplicationStatus, diagnostics: String = ""): Unit /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index b581790e158ac..8d4b96ed79933 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -45,6 +45,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -59,13 +60,19 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.uiHistoryAddress = uiHistoryAddress logInfo("Registering the ApplicationMaster") - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + synchronized { + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + registered = true + } new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args, preferredNodeLocations, securityMgr) } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = - amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + } + } override def getAttemptId() = { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) From d65fd554b4de1dbd8db3090b0e50994010d30e78 Mon Sep 17 00:00:00 2001 From: Hossein Date: Tue, 7 Oct 2014 11:46:26 -0700 Subject: [PATCH 207/315] [SPARK-3827] Very long RDD names are not rendered properly in web UI With Spark SQL we generate very long RDD names. These names are not properly rendered in the web UI. This PR fixes the rendering issue. [SPARK-3827] #comment Linking PR with JIRA Author: Hossein Closes #2687 from falaki/sparkTableUI and squashes the following commits: fd06409 [Hossein] Limit width of cell when RDD name is too long --- core/src/main/resources/org/apache/spark/ui/static/webui.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 445110d63e184..152bde5f6994f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -51,6 +51,11 @@ table.sortable thead { cursor: pointer; } +table.sortable td { + word-wrap: break-word; + max-width: 600px; +} + .progress { margin-bottom: 0px; position: relative } From 12e2551ea1773ae19559ecdada35d23608e6b0ec Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 7 Oct 2014 11:53:22 -0700 Subject: [PATCH 208/315] [SPARK-3808] PySpark fails to start in Windows Modified syntax error of *.cmd script. Author: Masayoshi TSUZUKI Closes #2669 from tsudukim/feature/SPARK-3808 and squashes the following commits: 7f804e6 [Masayoshi TSUZUKI] [SPARK-3808] PySpark fails to start in Windows --- bin/compute-classpath.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 9b9e40321ea93..3cd0579aea8d3 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -38,7 +38,7 @@ if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% -if "x%SPARK_CONF_DIR%"!="x" ( +if not "x%SPARK_CONF_DIR%"=="x" ( set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% ) else ( set CLASSPATH=%CLASSPATH%;%FWDIR%conf From 655032965fc7e2368dff9947fc024ac720ffd19c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 7 Oct 2014 12:06:12 -0700 Subject: [PATCH 209/315] [SPARK-3762] clear reference of SparkEnv after stop SparkEnv is cached in ThreadLocal object, so after stop and create a new SparkContext, old SparkEnv is still used by some threads, it will trigger many problems, for example, pyspark will have problem after restart SparkContext, because py4j use thread pool for RPC. This patch will clear all the references after stop a SparkEnv. cc mateiz tdas pwendell Author: Davies Liu Closes #2624 from davies/env and squashes the following commits: a69f30c [Davies Liu] deprecate getThreadLocal ba77ca4 [Davies Liu] remove getThreadLocal(), update docs ee62bb7 [Davies Liu] cleanup ThreadLocal of SparnENV 4d0ea8b [Davies Liu] clear reference of SparkEnv after stop --- .../scala/org/apache/spark/SparkEnv.scala | 19 ++++++++----------- .../apache/spark/api/python/PythonRDD.scala | 1 - .../org/apache/spark/executor/Executor.scala | 2 -- .../scala/org/apache/spark/rdd/PipedRDD.scala | 1 - .../apache/spark/scheduler/DAGScheduler.scala | 1 - .../spark/scheduler/TaskSchedulerImpl.scala | 2 -- .../streaming/scheduler/JobGenerator.scala | 1 - .../streaming/scheduler/JobScheduler.scala | 1 - .../streaming/scheduler/ReceiverTracker.scala | 1 - 9 files changed, 8 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72cac42cd2b2b..aba713cb4267a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils} * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these - * objects needs to have the right SparkEnv set. You can get the current environment with - * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. + * Spark code finds the SparkEnv through a global variable, so all the threads can access the same + * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). * * NOTE: This is not intended for external use. This is exposed for Shark and may be made private * in a future release. @@ -119,30 +118,28 @@ class SparkEnv ( } object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] - @volatile private var lastSetSparkEnv : SparkEnv = _ + @volatile private var env: SparkEnv = _ private[spark] val driverActorSystemName = "sparkDriver" private[spark] val executorActorSystemName = "sparkExecutor" def set(e: SparkEnv) { - lastSetSparkEnv = e - env.set(e) + env = e } /** - * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv - * previously set in any thread. + * Returns the SparkEnv. */ def get: SparkEnv = { - Option(env.get()).getOrElse(lastSetSparkEnv) + env } /** * Returns the ThreadLocal SparkEnv. */ + @deprecated("Use SparkEnv.get instead", "1.2") def getThreadLocal: SparkEnv = { - env.get() + env } private[spark] def create( diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 924141475383d..ad6eb9ef50277 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -196,7 +196,6 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { - SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9bbfcdc4a0b6e..616c7e6a46368 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -148,7 +148,6 @@ private[spark] class Executor( override def run() { val startTime = System.currentTimeMillis() - SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -158,7 +157,6 @@ private[spark] class Executor( val startGCTime = gcTime try { - SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 5d77d37378458..56ac7a69be0d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { - SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) // input the pipe context firstly diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8135cdbb4c31f..788eb1ff4e455 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -630,7 +630,6 @@ class DAGScheduler( protected def runLocallyWithinThread(job: ActiveJob) { var jobResult: JobResult = JobSucceeded try { - SparkEnv.set(env) val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4dc550413c13c..6d697e3d003f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl( * that tasks are balanced across the cluster. */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - // Mark each slave as alive and remember its hostname // Also track if new executor is added var newExecAvail = false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 374848358e700..7d73ada12d107 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -217,7 +217,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { - SparkEnv.set(ssc.env) Try(graph.generateJobs(time)) match { case Success(jobs) => val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1b034b9fb187c..cfa3cd8925c80 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -138,7 +138,6 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) - SparkEnv.set(ssc.env) } private def handleJobCompletion(job: Job) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 5307fe189d717..7149dbc12a365 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -202,7 +202,6 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { - SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") From bc87cc410fae59660c13b6ae1c14204df77237b8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 7 Oct 2014 12:20:12 -0700 Subject: [PATCH 210/315] [SPARK-3731] [PySpark] fix memory leak in PythonRDD The parent.getOrCompute() of PythonRDD is executed in a separated thread, it should release the memory reserved for shuffle and unrolling finally. Author: Davies Liu Closes #2668 from davies/leak and squashes the following commits: ae98be2 [Davies Liu] fix memory leak in PythonRDD --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ad6eb9ef50277..c74f86548ef85 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -247,6 +247,11 @@ private[spark] class PythonRDD( // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e worker.shutdownOutput() + } finally { + // Release memory used by this thread for shuffles + env.shuffleMemoryManager.releaseMemoryForThisThread() + // Release memory used by this thread for unrolling blocks + env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } From 553737c6e6d5ffa3b52a9888444f4beece5c5b1a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 7 Oct 2014 12:52:10 -0700 Subject: [PATCH 211/315] [SPARK-3825] Log more detail when unrolling a block fails Before: ``` 14/10/06 16:45:42 WARN CacheManager: Not enough space to cache partition rdd_0_2 in memory! Free memory is 481861527 bytes. ``` After: ``` 14/10/07 11:08:24 WARN MemoryStore: Not enough space to cache rdd_2_0 in memory! (computed 68.8 MB so far) 14/10/07 11:08:24 INFO MemoryStore: Memory use = 1088.0 B (blocks) + 445.1 MB (scratch space shared across 8 thread(s)) = 445.1 MB. Storage limit = 459.5 MB. ``` Author: Andrew Or Closes #2688 from andrewor14/cache-log-message and squashes the following commits: 28e33d6 [Andrew Or] Shy away from "unrolling" 5638c49 [Andrew Or] Grammar 39a0c28 [Andrew Or] Log more detail when unrolling a block fails --- .../scala/org/apache/spark/CacheManager.scala | 2 - .../apache/spark/storage/MemoryStore.scala | 45 ++++++++++++++++--- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index f8584b90cabe6..d89bb50076c9a 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { arr.iterator.asInstanceOf[Iterator[T]] case Right(it) => // There is not enough space to cache this partition in memory - logWarning(s"Not enough space to cache partition $key in memory! " + - s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.") val returnValues = it.asInstanceOf[Iterator[T]] if (putLevel.useDisk) { logWarning(s"Persisting partition $key to disk instead.") diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 0a09c24d61879..edbc729c17ade 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -132,8 +132,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) PutResult(res.size, res.data, droppedBlocks) case Right(iteratorValues) => // Not enough space to unroll this block; drop to disk if applicable - logWarning(s"Not enough space to store block $blockId in memory! " + - s"Free memory is $freeMemory bytes.") if (level.useDisk && allowPersistToDisk) { logWarning(s"Persisting block $blockId to disk instead.") val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) @@ -265,6 +263,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) Left(vector.toArray) } else { // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, vector.estimateSize()) Right(vector.iterator ++ values) } @@ -424,7 +423,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Reserve additional memory for unrolling blocks used by this thread. * Return whether the request is granted. */ - private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { @@ -439,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this thread for unrolling blocks. * If the amount is not specified, remove the current thread's allocation altogether. */ - private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { val threadId = Thread.currentThread().getId accountingLock.synchronized { if (memory < 0) { @@ -457,16 +456,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Return the amount of memory currently occupied for unrolling blocks across all threads. */ - private[spark] def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this thread. */ - private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) } + + /** + * Return the number of threads currently unrolling blocks. + */ + def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + + /** + * Log information about current memory usage. + */ + def logMemoryUsage(): Unit = { + val blocksMemory = currentMemory + val unrollMemory = currentUnrollMemory + val totalMemory = blocksMemory + unrollMemory + logInfo( + s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + + s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"Storage limit = ${Utils.bytesToString(maxMemory)}." + ) + } + + /** + * Log a warning for failing to unroll a block. + * + * @param blockId ID of the block we are trying to unroll. + * @param finalVectorSize Final size of the vector before unrolling failed. + */ + def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + logWarning( + s"Not enough space to cache $blockId in memory! " + + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" + ) + logMemoryUsage() + } } private[spark] case class ResultWithDroppedBlocks( From 446063eca98ae56d1ac61415f4c6e89699b8db02 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 7 Oct 2014 16:00:22 -0700 Subject: [PATCH 212/315] [SPARK-3777] Display "Executor ID" for Tasks in Stage page Now the Stage page only displays "Executor"(host) for tasks. However, there may be more than one Executors running in the same host. Currently, when some task is hung, I only know the host of the faulty executor. Therefore I have to check all executors in the host. Adding "Executor ID" in the Tasks table. would be helpful to locate the faulty executor. Here is the new page: ![add_executor_id_for_tasks](https://cloud.githubusercontent.com/assets/1000778/4505774/acb9648c-4afa-11e4-8826-8768a0a60cc9.png) Author: zsxwing Closes #2642 from zsxwing/SPARK-3777 and squashes the following commits: 37945af [zsxwing] Put Executor ID and Host into one cell 4bbe2c7 [zsxwing] [SPARK-3777] Display "Executor ID" for Tasks in Stage page --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index db01be596e073..2414e4c65237e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -103,7 +103,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskHeaders: Seq[String] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", + "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { }
    {info.status} {info.taskLocality}{info.host}{info.executorId} / {info.host} {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} From 3d7b36e0de26049e8b36b6705d8ff4224bde9eb1 Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Tue, 7 Oct 2014 16:40:16 -0700 Subject: [PATCH 213/315] [SPARK-3790][MLlib] CosineSimilarity Example Provide example for `RowMatrix.columnSimilarity()` Author: Reza Zadeh Closes #2622 from rezazadeh/dimsumexample and squashes the following commits: 8f20b82 [Reza Zadeh] update comment 379066d [Reza Zadeh] cache rows 792b81c [Reza Zadeh] Address review comments e573c7a [Reza Zadeh] Average absolute error b15685f [Reza Zadeh] Use scopt. Distribute evaluation. eca3dfd [Reza Zadeh] Documentation ac96fb2 [Reza Zadeh] Compute approximation error, add command line. 4533579 [Reza Zadeh] CosineSimilarity Example --- .../examples/mllib/CosineSimilarity.scala | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala new file mode 100644 index 0000000000000..6a3b0241ced7f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Compute the similar columns of a matrix, using cosine similarity. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + * + * Example invocation: + * + * bin/run-example mllib.CosineSimilarity \ + * --threshold 0.1 data/mllib/sample_svm_data.txt + */ +object CosineSimilarity { + case class Params(inputFile: String = null, threshold: Double = 0.1) + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("CosineSimilarity") { + head("CosineSimilarity: an example app.") + opt[Double]("threshold") + .required() + .text(s"threshold similarity: to tradeoff computation vs quality estimate") + .action((x, c) => c.copy(threshold = x)) + arg[String]("") + .required() + .text(s"input file, one row per line, space-separated") + .action((x, c) => c.copy(inputFile = x)) + note( + """ + |For example, the following command runs this app on a dataset: + | + | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \ + | examplesjar.jar \ + | --threshold 0.1 data/mllib/sample_svm_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName("CosineSimilarity") + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(params.inputFile).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + }.cache() + val mat = new RowMatrix(rows) + + // Compute similar columns perfectly, with brute force. + val exact = mat.columnSimilarities() + + // Compute similar columns with estimation using DIMSUM + val approx = mat.columnSimilarities(params.threshold) + + val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) } + val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) } + val MAE = exactEntries.leftOuterJoin(approxEntries).values.map { + case (u, Some(v)) => + math.abs(u - v) + case (u, None) => + math.abs(u) + }.mean() + + println(s"Average absolute error in estimate is: $MAE") + + sc.stop() + } +} From 098c7344e64e69dffdcf0d95fe1c9e65a54e98f3 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Tue, 7 Oct 2014 16:43:34 -0700 Subject: [PATCH 214/315] [SPARK-3486][MLlib][PySpark] PySpark support for Word2Vec mengxr Added PySpark support for Word2Vec Change list (1) PySpark support for Word2Vec (2) SerDe support of string sequence both on python side and JVM side (3) Test for SerDe of string sequence on JVM side Author: Liquan Pei Closes #2356 from Ishiihara/Word2Vec-python and squashes the following commits: 476ea34 [Liquan Pei] style fixes b13a0b9 [Liquan Pei] resolve merge conflicts and minor fixes 8671eba [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python daf88a6 [Liquan Pei] modification according to feedback a73fa19 [Liquan Pei] clean up 3d8007b [Liquan Pei] fix findSynonyms for vector 1bdcd2e [Liquan Pei] minor fixes cdef9f4 [Liquan Pei] add missing comments b7447eb [Liquan Pei] modify according to feedback b9a7383 [Liquan Pei] cache words RDD in fit 89490bf [Liquan Pei] add tests and Word2VecModelWrapper 78bbb53 [Liquan Pei] use pickle for seq string SerDe a264b08 [Liquan Pei] Merge remote-tracking branch 'upstream/master' into Word2Vec-python ca1e5ff [Liquan Pei] fix test 68e7276 [Liquan Pei] minor style fixes 48d5e72 [Liquan Pei] Functionality improvement 0ad3ac1 [Liquan Pei] minor fix c867fdf [Liquan Pei] add Word2Vec to pyspark --- .../mllib/api/python/PythonMLLibAPI.scala | 57 +++++- .../apache/spark/mllib/feature/Word2Vec.scala | 12 +- python/docs/pyspark.mllib.rst | 8 + python/pyspark/mllib/feature.py | 193 ++++++++++++++++++ python/run-tests | 1 + 5 files changed, 264 insertions(+), 7 deletions(-) create mode 100644 python/pyspark/mllib/feature.py diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e9f41758581e3..f7251e65e04f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.Word2VecModel import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. @@ -287,6 +289,59 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib Word2Vec fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + * @param dataJRDD input JavaRDD + * @param vectorSize size of vector + * @param learningRate initial learning rate + * @param numPartitions number of partitions + * @param numIterations number of iterations + * @param seed initial seed for random generator + * @return A handle to java Word2VecModelWrapper instance at python side + */ + def trainWord2Vec( + dataJRDD: JavaRDD[java.util.ArrayList[String]], + vectorSize: Int, + learningRate: Double, + numPartitions: Int, + numIterations: Int, + seed: Long): Word2VecModelWrapper = { + val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) + val word2vec = new Word2Vec() + .setVectorSize(vectorSize) + .setLearningRate(learningRate) + .setNumPartitions(numPartitions) + .setNumIterations(numIterations) + .setSeed(seed) + val model = word2vec.fit(data) + data.unpersist() + new Word2VecModelWrapper(model) + } + + private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(words) + ret.add(similarity) + ret + } + } + /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index fc1444705364a..d321994c2a651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -67,7 +67,7 @@ private case class VocabWord( class Word2Vec extends Serializable with Logging { private var vectorSize = 100 - private var startingAlpha = 0.025 + private var learningRate = 0.025 private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() @@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging { * Sets initial learning rate (default: 0.025). */ def setLearningRate(learningRate: Double): this.type = { - this.startingAlpha = learningRate + this.learningRate = learningRate this } @@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha + var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging { lwc = wordCount // TODO: discount by iteration? alpha = - startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } wc += sentence.size @@ -437,7 +437,7 @@ class Word2VecModel private[mllib] ( * Find synonyms of a word * @param word a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index e95d19e97f151..4548b8739ed91 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -20,6 +20,14 @@ pyspark.mllib.clustering module :undoc-members: :show-inheritance: +pyspark.mllib.feature module +------------------------------- + +.. automodule:: pyspark.mllib.feature + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.linalg module --------------------------- diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py new file mode 100644 index 0000000000000..a44a27fd3b6a6 --- /dev/null +++ b/python/pyspark/mllib/feature.py @@ -0,0 +1,193 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for feature in MLlib. +""" +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['Word2Vec', 'Word2VecModel'] + + +class Word2VecModel(object): + """ + class for Word2Vec model + """ + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def transform(self, word): + """ + :param word: a word + :return: vector representation of word + Transforms a word to its vector representation + + Note: local use only + """ + # TODO: make transform usable in RDD operations from python side + result = self._java_model.transform(word) + return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result))) + + def findSynonyms(self, x, num): + """ + :param x: a word or a vector representation of word + :param num: number of synonyms to find + :return: array of (word, cosineSimilarity) + Find synonyms of a word + + Note: local use only + """ + # TODO: make findSynonyms usable in RDD operations from python side + ser = PickleSerializer() + if type(x) == str: + jlist = self._java_model.findSynonyms(x, num) + else: + bytes = bytearray(ser.dumps(_convert_to_vector(x))) + vec = self._sc._jvm.SerDe.loads(bytes) + jlist = self._java_model.findSynonyms(vec, num) + words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist))) + return zip(words, similarity) + + +class Word2Vec(object): + """ + Word2Vec creates vector representation of words in a text corpus. + The algorithm first constructs a vocabulary from the corpus + and then learns vector representation of words in the vocabulary. + The vector representation can be used as features in + natural language processing and machine learning algorithms. + + We used skip-gram model in our implementation and hierarchical softmax + method to train the model. The variable names in the implementation + matches the original C implementation. + For original C implementation, see https://code.google.com/p/word2vec/ + For research papers, see + Efficient Estimation of Word Representations in Vector Space + and + Distributed Representations of Words and Phrases and their Compositionality. + + >>> sentence = "a b " * 100 + "a c " * 10 + >>> localDoc = [sentence, sentence] + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> syms = model.findSynonyms("a", 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + >>> vec = model.transform("a") + >>> len(vec) + 10 + >>> syms = model.findSynonyms(vec, 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + """ + def __init__(self): + """ + Construct Word2Vec instance + """ + self.vectorSize = 100 + self.learningRate = 0.025 + self.numPartitions = 1 + self.numIterations = 1 + self.seed = 42L + + def setVectorSize(self, vectorSize): + """ + Sets vector size (default: 100). + """ + self.vectorSize = vectorSize + return self + + def setLearningRate(self, learningRate): + """ + Sets initial learning rate (default: 0.025). + """ + self.learningRate = learningRate + return self + + def setNumPartitions(self, numPartitions): + """ + Sets number of partitions (default: 1). Use a small number for accuracy. + """ + self.numPartitions = numPartitions + return self + + def setNumIterations(self, numIterations): + """ + Sets number of iterations (default: 1), which should be smaller than or equal to number of + partitions. + """ + self.numIterations = numIterations + return self + + def setSeed(self, seed): + """ + Sets random seed. + """ + self.seed = seed + return self + + def fit(self, data): + """ + Computes the vector representation of each word in vocabulary. + + :param data: training data. RDD of subtype of Iterable[String] + :return: python Word2VecModel instance + """ + sc = data.context + ser = PickleSerializer() + vectorSize = self.vectorSize + learningRate = self.learningRate + numPartitions = self.numPartitions + numIterations = self.numIterations + seed = self.seed + + model = sc._jvm.PythonMLLibAPI().trainWord2Vec( + data._to_java_object_rdd(), vectorSize, + learningRate, numPartitions, numIterations, seed) + return Word2VecModel(sc, model) + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/run-tests b/python/run-tests index c713861eb77bb..63395f72788f9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -69,6 +69,7 @@ function run_mllib_tests() { echo "Run mllib tests ..." run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/feature.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" From b32bb72e812731d28bf05f2145314c63806f3335 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 7 Oct 2014 16:47:24 -0700 Subject: [PATCH 215/315] [SPARK-3832][MLlib] Upgrade Breeze dependency to 0.10 In Breeze 0.10, the L1regParam can be configured through anonymous function in OWLQN, and each component can be penalized differently. This is required for GLMNET in MLlib with L1/L2 regularization. https://github.com/scalanlp/breeze/commit/2570911026aa05aa1908ccf7370bc19cd8808a4c Author: DB Tsai Closes #2693 from dbtsai/breeze0.10 and squashes the following commits: 7a0c45c [DB Tsai] In Breeze 0.10, the L1regParam can be configured through anonymous function in OWLQN, and each component can be penalized differently. This is required for GLMNET in MLlib with L1/L2 regularization. https://github.com/scalanlp/breeze/commit/2570911026aa05aa1908ccf7370bc19cd8808a4c --- mllib/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index a5eeef88e9d62..696e9396f627c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.9 + 0.10 From 5912ca67140eed5dea66745aa3af4febdbd80781 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Tue, 7 Oct 2014 16:54:32 -0700 Subject: [PATCH 216/315] [SPARK-3398] [EC2] Have spark-ec2 intelligently wait for specific cluster states Instead of waiting arbitrary amounts of time for the cluster to reach a specific state, this patch lets `spark-ec2` explicitly wait for a cluster to reach a desired state. This is useful in a couple of situations: * The cluster is launching and you want to wait until SSH is available before installing stuff. * The cluster is being terminated and you want to wait until all the instances are terminated before trying to delete security groups. This patch removes the need for the `--wait` option and removes some of the time-based retry logic that was being used. Author: Nicholas Chammas Closes #2339 from nchammas/spark-ec2-wait-properly and squashes the following commits: 43a69f0 [Nicholas Chammas] short-circuit SSH check; linear backoff 9a9e035 [Nicholas Chammas] remove extraneous comment 26c5ed0 [Nicholas Chammas] replace print with write() bb67c06 [Nicholas Chammas] deprecate wait option; remove dead code 7969265 [Nicholas Chammas] fix long line (PEP 8) 126e4cf [Nicholas Chammas] wait for specific cluster states --- ec2/spark_ec2.py | 111 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 25 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 941dfb988b9fb..27f468ea4f395 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import tempfile import time import urllib2 +import warnings from optparse import OptionParser from sys import stderr import boto @@ -61,8 +62,8 @@ def parse_args(): "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") parser.add_option( - "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: %default)") + "-w", "--wait", type="int", + help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -195,18 +196,6 @@ def get_or_make_group(conn, name): return conn.create_security_group(name, "Spark EC2 group") -# Wait for a set of launched instances to exit the "pending" state -# (i.e. either to start running or to fail and be terminated) -def wait_for_instances(conn, instances): - while True: - for i in instances: - i.update() - if len([i for i in instances if i.state == 'pending']) > 0: - time.sleep(5) - else: - return - - # Check whether a given EC2 instance object is in a state we consider active, # i.e. not terminating or terminated. We count both stopping and stopped as # active since we can restart stopped clusters. @@ -619,14 +608,64 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up -def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): - print "Waiting for instances to start up..." - time.sleep(5) - wait_for_instances(conn, master_nodes) - wait_for_instances(conn, slave_nodes) - print "Waiting %d more seconds..." % wait_secs - time.sleep(wait_secs) +def is_ssh_available(host, opts): + "Checks if SSH is available on the host." + try: + with open(os.devnull, 'w') as devnull: + ret = subprocess.check_call( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=devnull, + stderr=devnull + ) + return ret == 0 + except subprocess.CalledProcessError as e: + return False + + +def is_cluster_ssh_available(cluster_instances, opts): + for i in cluster_instances: + if not is_ssh_available(host=i.ip_address, opts=opts): + return False + else: + return True + + +def wait_for_cluster_state(cluster_instances, cluster_state, opts): + """ + cluster_instances: a list of boto.ec2.instance.Instance + cluster_state: a string representing the desired state of all the instances in the cluster + value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as + 'running', 'terminated', etc. + (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) + """ + sys.stdout.write( + "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + ) + sys.stdout.flush() + + num_attempts = 0 + + while True: + time.sleep(3 * num_attempts) + + for i in cluster_instances: + s = i.update() # capture output to suppress print to screen in newer versions of boto + + if cluster_state == 'ssh-ready': + if all(i.state == 'running' for i in cluster_instances) and \ + is_cluster_ssh_available(cluster_instances, opts): + break + else: + if all(i.state == cluster_state for i in cluster_instances): + break + + num_attempts += 1 + + sys.stdout.write(".") + sys.stdout.flush() + + sys.stdout.write("\n") # Get number of local disks available for a given EC2 instance type. @@ -868,6 +907,16 @@ def real_main(): (opts, action, cluster_name) = parse_args() # Input parameter validation + if opts.wait is not None: + # NOTE: DeprecationWarnings are silent in 2.7+ by default. + # To show them, run Python with the -Wdefault switch. + # See: https://docs.python.org/3.5/whatsnew/2.7.html + warnings.warn( + "This option is deprecated and has no effect. " + "spark-ec2 automatically waits as long as necessary for clusters to startup.", + DeprecationWarning + ) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) @@ -890,7 +939,11 @@ def real_main(): (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": @@ -919,7 +972,11 @@ def real_main(): else: group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] - + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='terminated', + opts=opts + ) attempt = 1 while attempt <= 3: print "Attempt %d" % attempt @@ -1019,7 +1076,11 @@ def real_main(): for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: From b69c9fb6fb048509bbd8430fb697dc3a5ca4fe59 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 7 Oct 2014 16:54:49 -0700 Subject: [PATCH 217/315] [SPARK-3829] Make Spark logo image on the header of HistoryPage as a link to HistoryPage's page #1 There is a Spark logo on the header of HistoryPage. We can have too many HistoryPages if we run 20+ applications. So I think, it's useful if the logo is as a link to the HistoryPage's page number 1. Author: Kousuke Saruta Closes #2690 from sarutak/SPARK-3829 and squashes the following commits: 908c109 [Kousuke Saruta] Removed extra space. 00bfbd7 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3829 dd87480 [Kousuke Saruta] Made header Spark log image as a link to History Server's top page. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f0006b42aee4f..be69060fc3bf8 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -216,8 +216,10 @@ private[spark] object UIUtils extends Logging {

    - + + + {title}

    From 798ed22c289cf65f2249bf2f4250285685ca69e7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 7 Oct 2014 18:09:27 -0700 Subject: [PATCH 218/315] [SPARK-3412] [PySpark] Replace Epydoc with Sphinx to generate Python API docs Retire Epydoc, use Sphinx to generate API docs. Refine Sphinx docs, also convert some docstrings into Sphinx style. It looks like: ![api doc](https://cloud.githubusercontent.com/assets/40902/4538272/9e2d4f10-4dec-11e4-8d96-6e45a8fe51f9.png) Author: Davies Liu Closes #2689 from davies/docs and squashes the following commits: bf4a0a5 [Davies Liu] fix links 3fb1572 [Davies Liu] fix _static in jekyll 65a287e [Davies Liu] fix scripts and logo 8524042 [Davies Liu] Merge branch 'master' of github.com:apache/spark into docs d5b874a [Davies Liu] Merge branch 'master' of github.com:apache/spark into docs 4bc1c3c [Davies Liu] refactor 746d0b6 [Davies Liu] @param -> :param 240b393 [Davies Liu] replace epydoc with sphinx doc --- docs/README.md | 8 +-- docs/_config.yml | 3 + docs/_plugins/copy_api_dirs.rb | 19 +++--- python/docs/conf.py | 12 ++-- python/docs/index.rst | 6 +- python/epydoc.conf | 38 ----------- python/pyspark/__init__.py | 26 ++------ python/pyspark/conf.py | 8 +-- python/pyspark/context.py | 92 +++++++++++++------------- python/pyspark/mllib/classification.py | 32 ++++----- python/pyspark/mllib/linalg.py | 8 +-- python/pyspark/mllib/regression.py | 18 ++--- python/pyspark/mllib/util.py | 18 ++--- python/pyspark/rdd.py | 52 +++++++-------- python/pyspark/sql.py | 33 +++++---- 15 files changed, 167 insertions(+), 206 deletions(-) delete mode 100644 python/epydoc.conf diff --git a/docs/README.md b/docs/README.md index 79708c3df9106..0facecdd5f767 100644 --- a/docs/README.md +++ b/docs/README.md @@ -54,19 +54,19 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} -## API Docs (Scaladoc and Epydoc) +## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the -SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as +Similarly, you can build just the PySpark docs by running `make html` from the +SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs using [epydoc](http://epydoc.sourceforge.net/). +PySpark docs [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index 7bc3a78e2d265..f4bf242ac191b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -8,6 +8,9 @@ gems: kramdown: entity_output: numeric +include: + - _static + # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. SPARK_VERSION: 1.0.0-SNAPSHOT diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3b02e090aec28..4566a2fff562b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,19 +63,20 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - # Build Epydoc for Python - puts "Moving to python directory and building epydoc." - cd("../python") - puts `epydoc --config epydoc.conf` + # Build Sphinx docs for Python - puts "Moving back into docs dir." - cd("../docs") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + puts `make html` + + puts "Moving back into home dir." + cd("../../") puts "Making directory api/python" - mkdir_p "api/python" + mkdir_p "docs/api/python" - puts "cp -r ../python/docs/. api/python" - cp_r("../python/docs/.", "api/python") + puts "cp -r python/docs/_build/html/. docs/api/python" + cp_r("python/docs/_build/html/.", "docs/api/python") cd("..") end diff --git a/python/docs/conf.py b/python/docs/conf.py index c368cf81a003b..8e6324f058251 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.1' +version = '1.2-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '' +release = '1.2-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -102,7 +102,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = 'nature' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -121,7 +121,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "../../docs/img/spark-logo-hd.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -154,10 +154,10 @@ #html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +html_domain_indices = False # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False diff --git a/python/docs/index.rst b/python/docs/index.rst index 25b3f9bd93e63..d66e051b15371 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PySpark API reference! +Welcome to Spark Python API Docs! =================================== Contents: @@ -24,14 +24,12 @@ Core classes: Main entry point for Spark functionality. :class:`pyspark.RDD` - + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Indices and tables ================== -* :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/python/epydoc.conf b/python/epydoc.conf deleted file mode 100644 index 8593e08deda19..0000000000000 --- a/python/epydoc.conf +++ /dev/null @@ -1,38 +0,0 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Information about the project. -name: Spark 1.0.0 Python API Docs -url: http://spark.apache.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. -modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join - pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon - pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1a2e774738fe7..e39e6514d77a1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -20,33 +20,21 @@ Public classes: - - L{SparkContext} + - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - L{Broadcast} A broadcast variable that gets reused across tasks. - - L{Accumulator} + - L{Accumulator} An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - L{SparkConf} For configuring Spark. - - L{SparkFiles} + - L{SparkFiles} Access files shipped with jobs. - - L{StorageLevel} + - L{StorageLevel} Finer-grained cache persistence levels. -Spark SQL: - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - -Hive: - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. """ # The following block allows us to import python's random instead of mllib.random for scripts in diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b64875a3f495a..dc7cd0bce56f3 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. - @param loadDefaults: whether to load values from Java system + :param loadDefaults: whether to load values from Java system properties (True by default) - @param _jvm: internal parameter used to pass a handle to the + :param _jvm: internal parameter used to pass a handle to the Java VM; does not need to be set by users - @param _jconf: Optionally pass in an existing SparkConf handle + :param _jconf: Optionally pass in an existing SparkConf handle to use its parameters """ if _jconf: @@ -139,7 +139,7 @@ def setAll(self, pairs): """ Set multiple parameters, passed as a list of key-value pairs. - @param pairs: list of key-value pairs to set + :param pairs: list of key-value pairs to set """ for (k, v) in pairs: self._jconf.set(k, v) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a45d79d6424c7..6fb30d65c5edd 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -73,21 +73,21 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. - @param master: Cluster URL to connect to + :param master: Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - @param appName: A name for your job, to display on the cluster web UI. - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster + :param appName: A name for your job, to display on the cluster web UI. + :param sparkHome: Location where Spark is installed on cluster nodes. + :param pyFiles: Collection of .zip or .py files to send to the cluster and add to PYTHONPATH. These can be paths on the local file system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on + :param environment: A dictionary of environment variables to set on worker nodes. - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. - @param serializer: The serializer for RDDs. - @param conf: A L{SparkConf} object setting Spark properties. - @param gateway: Use an existing gateway and JVM, otherwise a new JVM + :param serializer: The serializer for RDDs. + :param conf: A L{SparkConf} object setting Spark properties. + :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. @@ -417,16 +417,16 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, 3. If this fails, the fallback is to call 'toString' on each key and value 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side - @param path: path to sequncefile - @param keyClass: fully qualified classname of key Writable class + :param path: path to sequncefile + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: - @param valueConverter: - @param minSplits: minimum splits in dataset + :param keyConverter: + :param valueConverter: + :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) @@ -446,18 +446,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -476,17 +476,17 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -507,18 +507,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java. - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -537,17 +537,17 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a765b1c4f7d87..cd43982191702 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -79,15 +79,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. :Allowed values: @@ -151,15 +151,15 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param regParam: The regularizer parameter (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param regParam: The regularizer parameter (default: 1.0). + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regType: The type of regularizer used for training our model. :Allowed values: @@ -238,10 +238,10 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - @param data: RDD of NumPy vectors, one per element, where the first + :param data: RDD of NumPy vectors, one per element, where the first coordinate is the label and the rest is the feature vector (e.g. a count vector). - @param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter """ sc = data.context jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51014a8ceb785..24c5480b2f753 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -238,8 +238,8 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print SparseVector(4, {1: 1.0, 3: 5.5}) @@ -458,8 +458,8 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 54f34a98337ca..12b322aaae796 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -31,8 +31,8 @@ class LabeledPoint(object): """ The features and labels of a data point. - @param label: Label for this data point. - @param features: Vector of features for this point (NumPy array, list, + :param label: Label for this data point. + :param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ @@ -145,15 +145,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a linear regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. :Allowed values: diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 8233d4e81f1ca..1357fd4fbc8aa 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -77,10 +77,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param numFeatures: number of features, which will be determined + :param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is already split into multiple files and you @@ -88,7 +88,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None features may not present in certain files, which leads to inconsistent feature dimensions. - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -126,8 +126,8 @@ def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. - @param data: an RDD of LabeledPoint to be saved - @param dir: directory to save the data + :param data: an RDD of LabeledPoint to be saved + :param dir: directory to save the data >>> from tempfile import NamedTemporaryFile >>> from fileinput import input @@ -149,10 +149,10 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e77669aad76b6..6797d50659a92 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -752,7 +752,7 @@ def max(self, key=None): """ Find the maximum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() @@ -768,7 +768,7 @@ def min(self, key=None): """ Find the minimum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() @@ -1115,9 +1115,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1135,16 +1135,16 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop job configuration, passed in as a dict (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1161,9 +1161,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1182,17 +1182,17 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: (None by default) - @param compressionCodecClass: (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: (None by default) + :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1212,8 +1212,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. - @param path: path to sequence file - @param compressionCodecClass: (None by default) + :param path: path to sequence file + :param compressionCodecClass: (None by default) """ pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) @@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - @param relativeSD Relative accuracy. Smaller values create + :param relativeSD Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 114644ab8b79d..3d5a281239d66 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,28 +15,37 @@ # limitations under the License. # +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{SchemaRDD} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. +""" -import sys -import types import itertools -import warnings import decimal import datetime import keyword import warnings from array import array from operator import itemgetter +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from itertools import chain, ifilter, imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -899,8 +908,8 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. - @param sparkContext: The SparkContext to wrap. - @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) @@ -1325,8 +1334,8 @@ class HiveContext(SQLContext): def __init__(self, sparkContext, hiveContext=None): """Create a new HiveContext. - @param sparkContext: The SparkContext to wrap. - @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new HiveContext in the JVM, instead we make all calls to this object. """ SQLContext.__init__(self, sparkContext) From c7818434fa8ae8e02a0d66183990077a4ba1436c Mon Sep 17 00:00:00 2001 From: Ahir Reddy Date: Tue, 7 Oct 2014 22:32:39 -0700 Subject: [PATCH 219/315] [SPARK-3836] [REPL] Spark REPL optionally propagate internal exceptions Optionally have the repl throw exceptions generated by interpreted code, instead of swallowing the exception and returning it as text output. This is useful when embedding the repl, otherwise it's not possible to know when user code threw an exception. Author: Ahir Reddy Closes #2695 from ahirreddy/repl-throw-exceptions and squashes the following commits: bad25ee [Ahir Reddy] Style Fixes f0e5b44 [Ahir Reddy] Fixed style 0d4413d [Ahir Reddy] propogate excetions from repl --- .../scala/org/apache/spark/repl/SparkIMain.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 6ddb6accd696b..646c68e60c2e9 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -84,9 +84,11 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) - extends SparkImports with Logging { - imain => + class SparkIMain( + initialSettings: Settings, + val out: JPrintWriter, + propagateExceptions: Boolean = false) + extends SparkImports with Logging { imain => val conf = new SparkConf() @@ -816,6 +818,10 @@ import org.apache.spark.util.Utils val resultName = FixedSessionNames.resultName def bindError(t: Throwable) = { + // Immediately throw the exception if we are asked to propagate them + if (propagateExceptions) { + throw unwrap(t) + } if (!bindExceptions) // avoid looping if already binding throw t From 35afdfd624fe19ce0c009cf065bb6794ee68e181 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 7 Oct 2014 23:26:24 -0700 Subject: [PATCH 220/315] [SPARK-3710] Fix Yarn integration tests on Hadoop 2.2. It seems some dependencies are not declared when pulling the 2.2 test dependencies, so we need to add them manually for the Yarn cluster to come up. These don't seem to be necessary for 2.3 and beyond, so restrict them to the hadoop-2.2 profile. Author: Marcelo Vanzin Closes #2682 from vanzin/SPARK-3710 and squashes the following commits: 701d4fb [Marcelo Vanzin] Add comment. 0540bdf [Marcelo Vanzin] [SPARK-3710] Fix Yarn integration tests on Hadoop 2.2. --- yarn/stable/pom.xml | 51 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 97eb0548e77c3..fe55d70ccc370 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -41,4 +41,55 @@ + + + + hadoop-2.2 + + 1.9 + + + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + + + + + From 7fca8f41c8889a41d9ab05ad0ab39c7639f657ed Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 8 Oct 2014 08:48:55 -0500 Subject: [PATCH 221/315] [SPARK-3788] [yarn] Fix compareFs to do the right thing for HDFS namespaces. HA and viewfs use namespaces instead of host names, so you can't resolve them since that will fail. So be smarter to avoid doing unnecessary work. Author: Marcelo Vanzin Closes #2649 from vanzin/SPARK-3788 and squashes the following commits: fedbc73 [Marcelo Vanzin] Update comment. c938845 [Marcelo Vanzin] Use Objects.equal() to avoid issues with ==. 9f7b571 [Marcelo Vanzin] [SPARK-3788] [yarn] Fix compareFs to do the right thing for HA, federation. --- .../apache/spark/deploy/yarn/ClientBase.scala | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ecac6eae6e03..14a0386b78978 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} import scala.util.{Try, Success, Failure} +import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -64,12 +65,12 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( @@ -771,15 +772,17 @@ private[spark] object ClientBase extends Logging { private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { val srcUri = srcFs.getUri() val dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { + + // In HA or when using viewfs, the host part of the URI may not actually be a host, but the + // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they + // match. + if (srcHost != null && dstHost != null && srcHost != dstHost) { try { srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() @@ -787,19 +790,9 @@ private[spark] object ClientBase extends Logging { case e: UnknownHostException => return false } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true } + + Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() } } From f18dd5962e4a18c3507de8147bde3a8f56380439 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 Oct 2014 11:53:43 -0500 Subject: [PATCH 222/315] [SPARK-3848] yarn alpha doesn't build on master yarn alpha build was broken by #2432 as it added an argument to YarnAllocator but not to yarn/alpha YarnAllocationHandler commit https://github.com/apache/spark/commit/79e45c9323455a51f25ed9acd0edd8682b4bbb88 Author: Kousuke Saruta Closes #2715 from sarutak/SPARK-3848 and squashes the following commits: bafb8d1 [Kousuke Saruta] Fixed parameters for the default constructor of alpha/YarnAllocatorHandler. --- .../org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6c93d8582330b..abd37834ed3cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -43,7 +43,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() From bc4418727b40c9b6ba5194ead6e2698539272280 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 8 Oct 2014 13:33:46 -0700 Subject: [PATCH 223/315] HOTFIX: Use correct Hadoop profile in build --- dev/run-tests | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests b/dev/run-tests index 4be2baaf48cd1..f47fcf66ff7e7 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -42,7 +42,7 @@ function handle_error () { elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi From b92bd5a2f29f7a9ce270540b6a828fa7ff205cbe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 8 Oct 2014 14:23:21 -0700 Subject: [PATCH 224/315] [SPARK-3841] [mllib] Pretty-print params for ML examples Provide a parent class for the Params case classes used in many MLlib examples, where the parent class pretty-prints the case class fields: Param1Name Param1Value Param2Name Param2Value ... Using this class will make it easier to print test settings to logs. Also, updated DecisionTreeRunner to print a little more info. CC: mengxr Author: Joseph K. Bradley Closes #2700 from jkbradley/dtrunner-update and squashes the following commits: cff873f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update 7a08ae4 [Joseph K. Bradley] code review comment updates b4d2043 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update d8228a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update 0fc9c64 [Joseph K. Bradley] Added abstract TestParams class for mllib example parameters 12b7798 [Joseph K. Bradley] Added abstract class TestParams for pretty-printing Params values 5f84f03 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update f7441b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update 19eb6fc [Joseph K. Bradley] Updated DecisionTreeRunner to print training time. --- .../spark/examples/mllib/AbstractParams.scala | 53 +++++++++++++++++++ .../examples/mllib/BinaryClassification.scala | 2 +- .../spark/examples/mllib/Correlations.scala | 1 + .../examples/mllib/CosineSimilarity.scala | 1 + .../examples/mllib/DecisionTreeRunner.scala | 15 +++++- .../spark/examples/mllib/DenseKMeans.scala | 2 +- .../examples/mllib/LinearRegression.scala | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 2 +- .../mllib/MultivariateSummarizer.scala | 1 + .../spark/examples/mllib/SampledRDDs.scala | 1 + .../examples/mllib/SparseNaiveBayes.scala | 2 +- 11 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala new file mode 100644 index 0000000000000..ae6057758d6fc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scala.reflect.runtime.universe._ + +/** + * Abstract class for parameter case classes. + * This overrides the [[toString]] method to print all case class fields by name and value. + * @tparam T Concrete parameter class. + */ +abstract class AbstractParams[T: TypeTag] { + + private def tag: TypeTag[T] = typeTag[T] + + /** + * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: + * { + * [field name]:\t[field value]\n + * [field name]:\t[field value]\n + * ... + * } + */ + override def toString: String = { + val tpe = tag.tpe + val allAccessors = tpe.declarations.collect { + case m: MethodSymbol if m.isCaseAccessor => m + } + val mirror = runtimeMirror(getClass.getClassLoader) + val instanceMirror = mirror.reflect(this) + allAccessors.map { f => + val paramName = f.name.toString + val fieldMirror = instanceMirror.reflectField(f) + val paramValue = fieldMirror.get + s" $paramName:\t$paramValue" + }.mkString("{\n", ",\n", "\n}") + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db1..1edd2432a0352 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index d6b2fe430e5a4..e49129c4e7844 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index 6a3b0241ced7f..cb1abbd18fd4d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -43,6 +43,7 @@ import org.apache.spark.{SparkConf, SparkContext} */ object CosineSimilarity { case class Params(inputFile: String = null, threshold: Double = 0.1) + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4adc91d2fbe65..837d0591478c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -62,7 +62,7 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) + fracTest: Double = 0.2) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -138,9 +138,11 @@ object DecisionTreeRunner { def run(params: Params) { - val conf = new SparkConf().setAppName("DecisionTreeRunner") + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") val sc = new SparkContext(conf) + println(s"DecisionTreeRunner with parameters:\n$params") + // Load training data and cache it. val origExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() @@ -235,7 +237,10 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) if (params.numTrees == 1) { + val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.numNodes < 20) { println(model.toDebugString) // Print full model. } else { @@ -259,8 +264,11 @@ object DecisionTreeRunner { } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { + val startTime = System.nanoTime() val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { @@ -275,8 +283,11 @@ object DecisionTreeRunner { println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { + val startTime = System.nanoTime() val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 89dfa26c2299c..11e35598baf50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -44,7 +44,7 @@ object DenseKMeans { input: String = null, k: Int = -1, numIterations: Int = 10, - initializationMode: InitializationMode = Parallel) + initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dffd..e1f9622350135 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 98aaedb9d7dc9..fc6678013b932 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { rank: Int = 10, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - implicitPrefs: Boolean = false) + implicitPrefs: Boolean = false) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 4532512c01f84..6e4e2d07f284b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext} object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index f01b8266e3fe3..663c12734af68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._ object SampledRDDs { case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a4..f1ff4e6911f5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -37,7 +37,7 @@ object SparseNaiveBayes { input: String = null, minPartitions: Int = 0, numFeatures: Int = -1, - lambda: Double = 1.0) + lambda: Double = 1.0) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() From add174aa56d291bc48ef73a42c39428c923efe31 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 Oct 2014 15:19:19 -0700 Subject: [PATCH 225/315] [SPARK-3843][Minor] Cleanup scalastyle.txt at the end of running dev/scalastyle dev/scalastyle create a log file 'scalastyle.txt'. it is overwrote per running but never deleted even though dev/mima and dev/lint-python delete their log files. Author: Kousuke Saruta Closes #2702 from sarutak/scalastyle-txt-cleanup and squashes the following commits: d6e238e [Kousuke Saruta] Fixed dev/scalastyle to cleanup scalastyle.txt --- dev/scalastyle | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/scalastyle b/dev/scalastyle index efb5f291ea3b7..c3b356bcb3c06 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -26,6 +26,8 @@ echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalasty >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") +rm scalastyle.txt + if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 From a85f24accd3266e0f97ee04d03c22b593d99c062 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 8 Oct 2014 17:03:47 -0700 Subject: [PATCH 226/315] [SPARK-3831] [SQL] Filter rule Improvement and bool expression optimization. If we write the filter which is always FALSE like SELECT * from person WHERE FALSE; 200 tasks will run. I think, 1 task is enough. And current optimizer cannot optimize the case NOT is duplicated like SELECT * from person WHERE NOT ( NOT (age > 30)); The filter rule above should be simplified Author: Kousuke Saruta Closes #2692 from sarutak/SPARK-3831 and squashes the following commits: 25f3e20 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3831 23c750c [Kousuke Saruta] Improved unsupported predicate test case a11b9f3 [Kousuke Saruta] Modified NOT predicate test case in PartitionBatchPruningSuite 8ea872b [Kousuke Saruta] Fixed the number of tasks when the data of LocalRelation is empty. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 12 ++++++++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 3 ++- .../sql/columnar/PartitionBatchPruningSuite.scala | 3 ++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4133feae8166..636d0b95583e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -299,6 +299,18 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, _) => or } + case not @ Not(exp) => + exp match { + case Literal(true, BooleanType) => Literal(false) + case Literal(false, BooleanType) => Literal(true) + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c16d0c624128..883f2ff521e20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -274,9 +274,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => + val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 69e0adbd3ee0d..f53acc8c9f718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) // With unsupported predicate checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) def checkBatchPruning( filter: String, From a42cc08d219c579019f613faa8d310e6069c06fe Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 17:04:49 -0700 Subject: [PATCH 227/315] [SPARK-3713][SQL] Uses JSON to serialize DataType objects This PR uses JSON instead of `toString` to serialize `DataType`s. The latter is not only hard to parse but also flaky in many cases. Since we already write schema information to Parquet metadata in the old style, we have to reserve the old `DataType` parser and ensure downward compatibility. The old parser is now renamed to `CaseClassStringParser` and moved into `object DataType`. JoshRosen davies Please help review PySpark related changes, thanks! Author: Cheng Lian Closes #2563 from liancheng/datatype-to-json and squashes the following commits: fc92eb3 [Cheng Lian] Reverts debugging code, simplifies primitive type JSON representation 438c75f [Cheng Lian] Refactors PySpark DataType JSON SerDe per comments 6b6387b [Cheng Lian] Removes debugging code 6a3ee3a [Cheng Lian] Addresses per review comments dc158b5 [Cheng Lian] Addresses PEP8 issues 99ab4ee [Cheng Lian] Adds compatibility est case for Parquet type conversion a983a6c [Cheng Lian] Adds PySpark support f608c6e [Cheng Lian] De/serializes DataType objects from/to JSON --- python/pyspark/sql.py | 153 ++++++------ .../catalyst/expressions/WrapDynamic.scala | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 229 ++++++++++++------ .../org/apache/spark/sql/SQLContext.scala | 9 +- .../spark/sql/parquet/ParquetTypes.scala | 6 +- .../org/apache/spark/sql/DataTypeSuite.scala | 28 +++ .../spark/sql/parquet/ParquetQuerySuite.scala | 16 +- 7 files changed, 277 insertions(+), 168 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3d5a281239d66..d3d36eb995ab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -34,6 +34,7 @@ import datetime import keyword import warnings +import json from array import array from operator import itemgetter from itertools import imap @@ -71,6 +72,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -214,6 +227,16 @@ def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -254,6 +277,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -292,6 +327,17 @@ def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"]) + class StructType(DataType): @@ -321,42 +367,30 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - index += 1 - # Handle the last data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - return datatype_list +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. - +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. - index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode and json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + else: + return _all_complex_types[json_value["type"]].fromJson(json_value) # Mapping Python types to Spark SQL DateType @@ -992,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}. @@ -1128,7 +1125,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1218,7 +1215,7 @@ def jsonFile(self, path, schema=None): if schema is None: srdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1288,7 +1285,7 @@ def func(iterator): if schema is None: srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1623,7 +1620,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1eb55715794a7..1a4ac06c7a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType { - def simpleString: String = "dynamic" -} +case object DynamicType extends DataType /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..1d375b8754182 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers +import org.json4s.JsonAST.JValue +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils -/** - * Utility functions for working with DataTypes. - */ -object DataType extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "StringType" ^^^ StringType | - "FloatType" ^^^ FloatType | - "IntegerType" ^^^ IntegerType | - "ByteType" ^^^ ByteType | - "ShortType" ^^^ ShortType | - "DoubleType" ^^^ DoubleType | - "LongType" ^^^ LongType | - "BinaryType" ^^^ BinaryType | - "BooleanType" ^^^ BooleanType | - "DecimalType" ^^^ DecimalType | - "TimestampType" ^^^ TimestampType - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) +object DataType { + def fromJson(json: String): DataType = parseDataType(parse(json)) + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + PrimitiveType.nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + } - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DecimalType" ^^^ DecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) - } + } - protected lazy val boolVal: Parser[Boolean] = - "true" ^^^ true | - "false" ^^^ false + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) - } + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } protected[types] def buildFormattedString( @@ -111,15 +165,19 @@ abstract class DataType { def isPrimitive: Boolean = false - def simpleString: String -} + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName -case object NullType extends DataType { - def simpleString: String = "null" + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) } +case object NullType extends DataType + object NativeType { - def all = Seq( + val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -139,6 +197,12 @@ trait PrimitiveType extends DataType { override def isPrimitive = true } +object PrimitiveType { + private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all + + private[sql] val nameToType = all.map(t => t.typeName -> t).toMap +} + abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "string" } case object BinaryType extends NativeType with PrimitiveType { @@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType { val res = x(i).compareTo(y(i)) if (res != 0) return res } - return x.length - y.length + x.length - y.length } } - def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "boolean" } case object TimestampType extends NativeType { @@ -187,8 +248,6 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } - - def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -222,7 +281,6 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "long" } case object IntegerType extends IntegralType { @@ -231,7 +289,6 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "integer" } case object ShortType extends IntegralType { @@ -240,7 +297,6 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "short" } case object ByteType extends IntegralType { @@ -249,7 +305,6 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -271,7 +326,6 @@ case object DecimalType extends FractionalType { private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = BigDecimalAsIfIntegral - def simpleString: String = "decimal" } case object DoubleType extends FractionalType { @@ -281,7 +335,6 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral - def simpleString: String = "double" } case object FloatType extends FractionalType { @@ -291,12 +344,12 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral - def simpleString: String = "float" } object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) + def typeName: String = "array" } /** @@ -309,11 +362,14 @@ object ArrayType { case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( - s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") DataType.buildFormattedString(elementType, s"$prefix |", builder) } - def simpleString: String = "array" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) } /** @@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) + } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + + def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { * have a name matching the given name, `null` will be returned. */ def apply(name: String): StructField = { - nameToField.get(name).getOrElse( - throw new IllegalArgumentException(s"Field ${name} does not exist.")) + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) } /** @@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet - if (!nonExistFields.isEmpty) { + if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } @@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType { fields.foreach(field => field.buildFormattedString(prefix, builder)) } - def simpleString: String = "struct" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> fields.map(_.jsonValue)) } object MapType { @@ -394,6 +459,8 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) + + def simpleName = "map" } /** @@ -407,12 +474,16 @@ case class MapType( valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } - def simpleString: String = "map" + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7a55c5bf97a71..35561cac3e5e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.SparkStrategies +import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It is only used by PySpark. */ private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) + DataType.fromJson(dataTypeString) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2941b9793597f..e6389cf77a4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging { } def convertFromString(string: String): Seq[Attribute] = { - DataType(string) match { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).toString + StructType.fromAttributes(schema).json } def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 8fb59c5830f6d..100ecb45e9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.types.DataType + class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite { struct(Set("b", "d", "e", "f")) } } + + def checkDataTypeJsonRepr(dataType: DataType): Unit = { + test(s"JSON - $dataType") { + assert(DataType.fromJson(dataType.json) === dataType) + } + } + + checkDataTypeJsonRepr(BooleanType) + checkDataTypeJsonRepr(ByteType) + checkDataTypeJsonRepr(ShortType) + checkDataTypeJsonRepr(IntegerType) + checkDataTypeJsonRepr(LongType) + checkDataTypeJsonRepr(FloatType) + checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(TimestampType) + checkDataTypeJsonRepr(StringType) + checkDataTypeJsonRepr(BinaryType) + checkDataTypeJsonRepr(ArrayType(DoubleType, true)) + checkDataTypeJsonRepr(ArrayType(StringType, false)) + checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) + checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeJsonRepr( + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 07adf731405af..25e41ecf28e2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } - + test("Querying on empty parquet throws exception (SPARK-3536)") { val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) @@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result1.size === 0) Utils.deleteRecursively(tmpdir) } + + test("DataType string parser compatibility") { + val schema = StructType(List( + StructField("c1", IntegerType, false), + StructField("c2", BinaryType, false))) + + val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) + val fromJson = ParquetTypesConverter.convertFromString(schema.json) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + } + } } From 00b7791720e50119a98084b2e8755e1b593ca55f Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Wed, 8 Oct 2014 17:16:54 -0700 Subject: [PATCH 228/315] [SQL][Doc] Keep Spark SQL README.md up to date marmbrus Update README.md to be consistent with Spark 1.1 Author: Liquan Pei Closes #2706 from Ishiihara/SparkSQL-readme and squashes the following commits: 33b9d4b [Liquan Pei] keep README.md up to date --- sql/README.md | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/README.md b/sql/README.md index 31f9152344086..c84534da9a3d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,38 +44,37 @@ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.ExecutedQuery = -SELECT * FROM (SELECT * FROM src) a -=== Query Plan === -Project [key#6:0.0,value#7:0.1] - HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.SchemaRDD = +== Query Plan == +== Physical Plan == +HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None ``` Query results are RDDs and can be operated as such. ``` scala> query.collect() -res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these RDDs using the query DSL. ``` -scala> query.where('key === 100).toRdd.collect() -res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +scala> query.where('key === 100).collect() +res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) ``` -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. ```scala -scala> query.logicalPlan -res1: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} - Project {key#0,value#1} +scala> query.queryExecution.analyzed +res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] + Project [key#10,value#11] MetastoreRelation default, src, None -scala> query.logicalPlan transform { +scala> query.queryExecution.analyzed transform { | case Project(projectList, child) if projectList == child.output => child | } -res2: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} +res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] MetastoreRelation default, src, None ``` From 4ec931951fea4efbfe5db39cf581704df7d2775b Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 8 Oct 2014 17:52:27 -0700 Subject: [PATCH 229/315] [SPARK-3707] [SQL] Fix bug of type coercion in DIV Calling `BinaryArithmetic.dataType` will throws exception until it's resolved, but in type coercion rule `Division`, seems doesn't follow this. Author: Cheng Hao Closes #2559 from chenghao-intel/type_coercion and squashes the following commits: 199a85d [Cheng Hao] Simplify the divide rule dc55218 [Cheng Hao] fix bug of type coercion in div --- .../catalyst/analysis/HiveTypeCoercion.scala | 7 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 40 +++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 79e5283e86a37..64881854df7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -348,8 +348,11 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Decimal and Double remain the same - case d: Divide if d.dataType == DoubleType => d - case d: Divide if d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType == DoubleType => d + case d: Divide if d.resolved && d.dataType == DecimalType => d + + case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) + case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5809a108ff62e..7b45738c4fc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) @@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) before { caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) @@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[RuntimeException] { caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) } - assert(e.getMessage === "Table Not Found: tAbLe") + assert(e.getMessage == "Table Not Found: tAbLe") assert( caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === @@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } assert(e.getMessage().toLowerCase.contains("unresolved plan")) } + + test("divide should be casted into fractional types") { + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) + + val expr0 = 'a / 2 + val expr1 = 'a / 'b + val expr2 = 'a / 'c + val expr3 = 'a / 'd + val expr4 = 'e / 'e + val plan = caseInsensitiveAnalyze(Project( + Alias(expr0, s"Analyzer($expr0)")() :: + Alias(expr1, s"Analyzer($expr1)")() :: + Alias(expr2, s"Analyzer($expr2)")() :: + Alias(expr3, s"Analyzer($expr3)")() :: + Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) + assert(pl(1).dataType == DoubleType) + assert(pl(2).dataType == DoubleType) + assert(pl(3).dataType == DecimalType) + assert(pl(4).dataType == DoubleType) + } } From e7033572330bd48b2438f218b0d2cd3fccdeb362 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 18:11:18 -0700 Subject: [PATCH 230/315] [SPARK-3810][SQL] Makes PreInsertionCasts handle partitions properly Includes partition keys into account when applying `PreInsertionCasts` rule. Author: Cheng Lian Closes #2672 from liancheng/fix-pre-insert-casts and squashes the following commits: def1a1a [Cheng Lian] Makes PreInsertionCasts handle partitions properly --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 15 +++----- .../sql/hive/execution/HiveQuerySuite.scala | 36 +++++++++++++++++++ 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cc0605b0adb35..addd5bed8426d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,31 +19,28 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} +import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import HiveMetastoreTypes._ + import org.apache.spark.sql.hive.HiveMetastoreTypes._ /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) @@ -137,10 +134,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. val tableOutputDataTypes = - table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) + (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e282a9ade40c..2829105f43716 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -675,6 +676,41 @@ class HiveQuerySuite extends HiveComparisonTest { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreInsertionCasts static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + + test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" From 3e4f09d2fce9dcf45eaaca827f2cf15c9d4a6c75 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 8 Oct 2014 18:13:22 -0700 Subject: [PATCH 231/315] [SQL] Prevents per row dynamic dispatching and pattern matching when inserting Hive values Builds all wrappers at first according to object inspector types to avoid per row costs. Author: Cheng Lian Closes #2592 from liancheng/hive-value-wrapper and squashes the following commits: 9696559 [Cheng Lian] Passes all tests 4998666 [Cheng Lian] Prevents per row dynamic dispatching and pattern matching when inserting Hive values --- .../hive/execution/InsertIntoHiveTable.scala | 64 ++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f8b4e898ec41d..f0785d8882636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,33 +69,36 @@ case class InsertIntoHiveTable( * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. */ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) - case (obj, _) => - obj + case _ => + identity[Any] } def saveAsHiveFile( @@ -103,7 +106,7 @@ case class InsertIntoHiveTable( valueClass: Class[_], fileSinkConf: FileSinkDesc, conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { + writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -122,7 +125,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -131,6 +134,7 @@ case class InsertIntoHiveTable( .asInstanceOf[StructObjectInspector] val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it @@ -141,13 +145,13 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) i += 1 } - val writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + writerContainer + .getLocalFileWriter(row) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() @@ -207,7 +211,7 @@ case class InsertIntoHiveTable( // Report error if any static partition appears after a dynamic partition val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } From bcb1ae049b447c37418747e0a262f54f9fc1664a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 8 Oct 2014 18:17:01 -0700 Subject: [PATCH 232/315] [SPARK-3857] Create joins package for various join operators. Author: Reynold Xin Closes #2719 from rxin/sql-join-break and squashes the following commits: 0c0082b [Reynold Xin] Fix line length. cbc664c [Reynold Xin] Rename join -> joins package. a070d44 [Reynold Xin] Fix line length in HashJoin a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators. --- .../spark/sql/execution/SparkStrategies.scala | 41 +- .../apache/spark/sql/execution/joins.scala | 624 ------------------ .../execution/joins/BroadcastHashJoin.scala | 62 ++ .../joins/BroadcastNestedLoopJoin.scala | 144 ++++ .../execution/joins/CartesianProduct.scala | 40 ++ .../spark/sql/execution/joins/HashJoin.scala | 123 ++++ .../sql/execution/joins/HashOuterJoin.scala | 222 +++++++ .../sql/execution/joins/LeftSemiJoinBNL.scala | 73 ++ .../execution/joins/LeftSemiJoinHash.scala | 67 ++ .../execution/joins/ShuffledHashJoin.scala | 49 ++ .../spark/sql/execution/joins/package.scala | 37 ++ .../org/apache/spark/sql/JoinSuite.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 3 +- .../spark/sql/hive/StatisticsSuite.scala | 2 +- 15 files changed, 844 insertions(+), 646 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 883f2ff521e20..bbf17b9fadf86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( + val semiJoin = joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + joins.BuildRight } else { - BuildLeft + joins.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = joins.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. - */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. - val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..d88ab6367a1b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + sparkContext.broadcast(buildPlan.executeCollect()) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..36aad13778bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. + + /** BuildRight means the right relation <=> the broadcast relation. */ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. + sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala new file mode 100644 index 0000000000000..76c14c02aab34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala new file mode 100644 index 0000000000000..472b2e6ca6b4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) + @transient protected lazy val streamSideKeyGenerator = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala new file mode 100644 index 0000000000000..b73041d306b36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..60003d1900d85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = left.output + + /** The Streamed Relation */ + override def left = streamed + /** The Broadcast relation */ + override def right = broadcast + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala new file mode 100644 index 0000000000000..ea7babf3be948 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..8247304c1dc2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. + */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala new file mode 100644 index 0000000000000..7f2ab1765b28f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. + */ +package object joins { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..b9b196ea5a46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ From f706823b71c763fa8e8ceb9e1bd916d8dca7a639 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Oct 2014 22:25:15 -0700 Subject: [PATCH 233/315] Fetch from branch v4 in Spark EC2 script. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 27f468ea4f395..0d6b82b4944f3 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -583,7 +583,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3") + ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) From 9c439d33160ef3b31173381735dfa8cfb7d552ba Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Oct 2014 22:35:14 -0700 Subject: [PATCH 234/315] [SPARK-3856][MLLIB] use norm operator after breeze 0.10 upgrade Got warning msg: ~~~ [warn] /Users/meng/src/spark/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala:50: method norm in trait NumericOps is deprecated: Use norm(XXX) instead of XXX.norm [warn] var norm = vector.toBreeze.norm(p) ~~~ dbtsai Author: Xiangrui Meng Closes #2718 from mengxr/SPARK-3856 and squashes the following commits: 4f38169 [Xiangrui Meng] use norm operator --- .../scala/org/apache/spark/mllib/feature/Normalizer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 3afb47767281c..4734251127bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -47,7 +47,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = vector.toBreeze.norm(p) + var norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. From b9df8af62e8d7b263a668dfb6e9668ab4294ea37 Mon Sep 17 00:00:00 2001 From: Anand Avati Date: Wed, 8 Oct 2014 23:45:17 -0700 Subject: [PATCH 235/315] [SPARK-2805] Upgrade to akka 2.3.4 Upgrade to akka 2.3.4 Author: Anand Avati Closes #1685 from avati/SPARK-1812-akka-2.3 and squashes the following commits: 57a2315 [Anand Avati] SPARK-1812: streaming - remove tests which depend on akka.actor.IO 2a551d3 [Anand Avati] SPARK-1812: core - upgrade to akka 2.3.4 --- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/AppClient.scala | 2 +- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- pom.xml | 2 +- .../spark/streaming/InputStreamsSuite.scala | 71 ------------------- 6 files changed, 6 insertions(+), 77 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/pom.xml b/pom.xml index 7756c89b00cad..3b6d4ecbae2c1 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..6107fcdc447b6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -144,59 +142,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -378,22 +323,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { From 86b392942daf61fed2ff7490178b128107a0e856 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 9 Oct 2014 00:00:24 -0700 Subject: [PATCH 236/315] [SPARK-3844][UI] Truncate appName in WebUI if it is too long Truncate appName in WebUI if it is too long. Author: Xiangrui Meng Closes #2707 from mengxr/truncate-app-name and squashes the following commits: 87834ce [Xiangrui Meng] move scala import below java c7111dc [Xiangrui Meng] truncate appName in WebUI if it is too long --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index be69060fc3bf8..32e6b15bb0999 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.text.SimpleDateFormat import java.util.{Locale, Date} import scala.xml.Node + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ @@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging { refreshInterval: Option[Int] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • {tab.name} @@ -187,7 +189,9 @@ private[spark] object UIUtils extends Logging { - +
  • From 13cab5ba44e2f8d2d2204b3b0d39d7c23a819bdb Mon Sep 17 00:00:00 2001 From: nartz Date: Thu, 9 Oct 2014 00:02:11 -0700 Subject: [PATCH 237/315] add spark.driver.memory to config docs It took me a minute to track this down, so I thought it could be useful to have it in the docs. I'm unsure if 512mb is the default for spark.driver.memory? Also - there could be a better value for the 'description' to differentiate it from spark.executor.memory. Author: nartz Author: Nathan Artz Closes #2410 from nartz/docs/add-spark-driver-memory-to-config-docs and squashes the following commits: a2f6c62 [nartz] Update configuration.md 74521b8 [Nathan Artz] add spark.driver.memory to config docs --- docs/configuration.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..f311f0d2a6206 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -103,6 +103,14 @@ of the most common options to set are: (e.g. 512m, 2g).
    spark.driver.memory512m + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). +
    spark.serializer org.apache.spark.serializer.
    JavaSerializer
    {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}{UIUtils.formatDuration(v.taskTime)}{UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} + {Utils.bytesToString(v.inputBytes)} + {Utils.bytesToString(v.shuffleRead)} + {Utils.bytesToString(v.shuffleWrite)} + {Utils.bytesToString(v.memoryBytesSpilled)} + {Utils.bytesToString(v.diskBytesSpilled)}
    {inputReadWithUnit}{shuffleReadWithUnit}{shuffleWriteWithUnit}{inputReadWithUnit}{shuffleReadWithUnit}{shuffleWriteWithUnit} {rdd.numCachedPartitions} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}{Utils.bytesToString(rdd.memSize)}{Utils.bytesToString(rdd.tachyonSize)}{Utils.bytesToString(rdd.diskSize)}{Utils.bytesToString(rdd.memSize)}{Utils.bytesToString(rdd.tachyonSize)}{Utils.bytesToString(rdd.diskSize)}
    spark.history.fs.logDirectory(none) + Directory that contains application event logs to be loaded by the history server +
    spark.history.fs.updateInterval 10
    spark.akka.heartbeat.pauses6006000 This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause From be2ec4a91d14f48e6323989fb0e0226a9d65bf7e Mon Sep 17 00:00:00 2001 From: Kun Li Date: Thu, 16 Oct 2014 19:00:10 -0700 Subject: [PATCH 301/315] [SQL]typo in HiveFromSpark Author: Kun Li Closes #2809 from jackylk/patch-1 and squashes the following commits: 46c926b [Kun Li] typo in HiveFromSpark --- .../org/apache/spark/examples/sql/hive/HiveFromSpark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index e26f213e8afa8..0c52ef8ed96ac 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -28,7 +28,7 @@ object HiveFromSpark { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) - // A local hive context creates an instance of the Hive Metastore in process, storing the + // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. val hiveContext = new HiveContext(sc) From 642b246beb7879978d31f2e6e97de7e06c74dcb7 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Thu, 16 Oct 2014 19:07:37 -0700 Subject: [PATCH 302/315] [SPARK-3941][CORE] _remainingmem should not increase twice when updateBlockInfo In BlockManagermasterActor, _remainingMem would increase memSize for twice when updateBlockInfo if new storageLevel is invalid and old storageLevel is "useMemory". Also, _remainingMem should increase with original memory size instead of new memSize. Author: Zhang, Liye Closes #2792 from liyezhang556520/spark-3941-remainMem and squashes the following commits: 3d487cc [Zhang, Liye] make the code concise 0380a32 [Zhang, Liye] [SPARK-3941][CORE] _remainingmem should not increase twice when updateBlockInfo --- .../apache/spark/storage/BlockManagerMasterActor.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6a06257ed0c08..088f06e389d83 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -457,16 +457,18 @@ private[spark] class BlockManagerInfo( if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize if (originalLevel.useMemory) { - _remainingMem += memSize + _remainingMem += originalMemSize } } if (storageLevel.isValid) { /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * But the memSize here indicates the data size in or dropped from memory, + * The memSize here indicates the data size in or dropped from memory, * tachyonSize here indicates the data size in or dropped from Tachyon, * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. @@ -493,7 +495,6 @@ private[spark] class BlockManagerInfo( val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), Utils.bytesToString(_remainingMem))) From e7f4ea8a52f0d3d56684b4f9caadce978eac4816 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 16 Oct 2014 19:12:39 -0700 Subject: [PATCH 303/315] [SPARK-3890][Docs]remove redundant spark.executor.memory in doc Introduced in https://github.com/pwendell/spark/commit/f7e79bc42c1635686c3af01eef147dae92de2529, I'm not sure why we need two spark.executor.memory here. Author: WangTaoTheTonic Author: WangTao Closes #2745 from WangTaoTheTonic/redundantconfig and squashes the following commits: e7564dc [WangTao] too long line fdbdb1f [WangTaoTheTonic] trivial workaround d06b6e5 [WangTaoTheTonic] remove redundant spark.executor.memory in doc --- docs/configuration.md | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 8515ee045177f..f0204c640bc89 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -161,14 +161,6 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment - - - - - @@ -365,7 +357,7 @@ Apart from these, the following properties are also available, and may be useful @@ -880,8 +872,8 @@ Apart from these, the following properties are also available, and may be useful @@ -893,7 +885,7 @@ Apart from these, the following properties are also available, and may be useful to wait for before scheduling begins. Specified as a double between 0 and 1. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredResourcesWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime. From 56fd34af52a18230bf3ea7b041f2a184eddc1103 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 16 Oct 2014 19:22:02 -0700 Subject: [PATCH 304/315] [SPARK-3741] Add afterExecute for handleConnectExecutor Sorry. I found that I forgot to add `afterExecute` for `handleConnectExecutor` in #2593. Author: zsxwing Closes #2794 from zsxwing/SPARK-3741 and squashes the following commits: a0bc4dd [zsxwing] Add afterExecute for handleConnectExecutor --- .../apache/spark/network/nio/ConnectionManager.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 9396b6ba84e7e..bda4bf50932c3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -117,7 +117,16 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.connect.threads.max", 8), conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) + Utils.namedThreadFactory("handle-connect-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleConnectExecutor is not handled properly", t) + } + } + + } private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation From dedace83f35cba0f833d962acbd75572318948c4 Mon Sep 17 00:00:00 2001 From: yantangzhai Date: Thu, 16 Oct 2014 19:25:37 -0700 Subject: [PATCH 305/315] [SPARK-3067] JobProgressPage could not show Fair Scheduler Pools section sometimes JobProgressPage could not show Fair Scheduler Pools section sometimes. SparkContext starts webui and then postEnvironmentUpdate. Sometimes JobProgressPage is accessed between webui starting and postEnvironmentUpdate, then the lazy val isFairScheduler will be false. The Fair Scheduler Pools section will not display any more. Author: yantangzhai Author: YanTangZhai Closes #1966 from YanTangZhai/SPARK-3067 and squashes the following commits: d4323f8 [yantangzhai] update [SPARK-3067] JobProgressPage could not show Fair Scheduler Pools section sometimes 8a00106 [YanTangZhai] Merge pull request #6 from apache/master b6391cc [yantangzhai] revert [SPARK-3067] JobProgressPage could not show Fair Scheduler Pools section sometimes d2226cd [yantangzhai] [SPARK-3067] JobProgressPage could not show Fair Scheduler Pools section sometimes cbcba66 [YanTangZhai] Merge pull request #3 from apache/master aac7f7b [yantangzhai] [SPARK-3067] JobProgressPage could not show Fair Scheduler Pools section sometimes cdef539 [YanTangZhai] Merge pull request #1 from apache/master --- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b709b8880ba76..354116286c77d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -238,7 +238,6 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } - ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) @@ -342,6 +341,10 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() postApplicationStart() + // Bind the SparkUI after starting the task scheduler + // because certain pages and listeners depend on it + ui.foreach(_.bind()) + private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack From e678b9f02a2936b35c95e91a5f0ff388b5720261 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 16 Oct 2014 19:43:33 -0700 Subject: [PATCH 306/315] [SPARK-3973] Print call site information for broadcasts Its hard to debug which broadcast variables refer to what in a big codebase. Printing call site information helps in debugging. Author: Shivaram Venkataraman Closes #2829 from shivaram/spark-broadcast-print and squashes the following commits: cd6dbdf [Shivaram Venkataraman] Print call site information for broadcasts --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 354116286c77d..dd3157990ef2d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -818,6 +818,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + val callSite = getCallSite + logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } From c351862064ed7d2031ea4c8bf33881e5f702ea0a Mon Sep 17 00:00:00 2001 From: likun Date: Fri, 17 Oct 2014 10:33:45 -0700 Subject: [PATCH 307/315] [SPARK-3935][Core] log the number of records that has been written There is a unused variable(count) in saveAsHadoopDataset in PairRDDFunctions.scala. The initial idea of this variable seems to count the number of records, so I am adding a log statement to log the number of records that has been written to the writer. Author: likun Author: jackylk Closes #2791 from jackylk/SPARK-3935 and squashes the following commits: a874047 [jackylk] removing the unused variable in PairRddFunctions.scala 3bf43c7 [likun] log the number of records has been written --- core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 929ded58a3bd5..ac96de86dd6d4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1032,10 +1032,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { - var count = 0 while (iter.hasNext) { val record = iter.next() - count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } } finally { From 803e7f087797bae643754f8db88848a17282ca6e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Oct 2014 13:45:10 -0500 Subject: [PATCH 308/315] [SPARK-3979] [yarn] Use fs's default replication. This avoids issues when HDFS is configured in a way that would not allow the hardcoded default replication of "3". Note: getDefaultReplication(Path) was added in 0.23.3, and the oldest one available on Maven Central is 0.23.7, so I chose to not add code to access that method via reflection. Author: Marcelo Vanzin Closes #2831 from vanzin/SPARK-3979 and squashes the following commits: b0e3a97 [Marcelo Vanzin] [SPARK-3979] [yarn] Use fs's default replication. --- .../main/scala/org/apache/spark/deploy/yarn/ClientBase.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 14a0386b78978..0efac4ea63702 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -143,7 +143,8 @@ private[spark] trait ClientBase extends Logging { val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", + fs.getDefaultReplication(dst)).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) From adcb7d3350032dda69a43de724c8bdff5fef2c67 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 17 Oct 2014 14:12:07 -0700 Subject: [PATCH 309/315] [SPARK-3855][SQL] Preserve the result attribute of python UDFs though transformations In the current implementation it was possible for the reference to change after analysis. Author: Michael Armbrust Closes #2717 from marmbrus/pythonUdfResults and squashes the following commits: da14879 [Michael Armbrust] Fix test 6343bcb [Michael Armbrust] add test 9533286 [Michael Armbrust] Correctly preserve the result attribute of python UDFs though transformations --- python/pyspark/tests.py | 6 ++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/execution/pythonUdfs.scala | 12 ++++++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ceab57464f013..f5ccf31abb3fa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -683,6 +683,12 @@ def test_udf(self): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(u"4", res[0]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4f1af7234d551..79e4ddb8c4f5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -295,7 +295,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child) => + case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 0977da3e8577c..be729e5d244b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } } +object EvaluatePython { + def apply(udf: PythonUDF, child: LogicalPlan) = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) +} + /** * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. */ @DeveloperApi -case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { - val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { def output = child.output :+ resultAttribute } From 23f6171d633d4347ca4aa8ec7cb7bd57342b21b5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 17 Oct 2014 14:49:44 -0700 Subject: [PATCH 310/315] [SPARK-3985] [Examples] fix file path using os.path.join Author: Daoyuan Wang Closes #2834 from adrian-wang/sqlpypath and squashes the following commits: da7aa95 [Daoyuan Wang] fix file path using path.join --- examples/src/main/python/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index eefa022f1927c..d2c5ca48c6cb8 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -48,7 +48,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. - path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json" + path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") # Create a SchemaRDD from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root From 477c6481cca94b15c9c8b43e674f220a1cda1dd1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 17 Oct 2014 15:02:57 -0700 Subject: [PATCH 311/315] [SPARK-3934] [SPARK-3918] [mllib] Bug fixes for RandomForest, DecisionTree SPARK-3934: When run with a mix of unordered categorical and continuous features, on multiclass classification, RandomForest fails. The bug is in the sanity checks in getFeatureOffset and getLeftRightFeatureOffsets, which use the wrong indices for checking whether features are unordered. Fix: Remove the sanity checks since they are not really needed, and since they would require DTStatsAggregator to keep track of an extra set of indices (for the feature subset). Added test to RandomForestSuite which failed with old version but now works. SPARK-3918: Added baggedInput.unpersist at end of training. Also: * I removed DTStatsAggregator.isUnordered since it is no longer used. * DecisionTreeMetadata: Added logWarning when maxBins is automatically reduced. * Updated DecisionTreeRunner to explicitly fix the test data to have the same number of features as the training data. This is a temporary fix which should eventually be replaced by pre-indexing both datasets. * RandomForestModel: Updated toString to print total number of nodes in forest. * Changed Predict class to be public DeveloperApi. This was necessary to allow users to create their own trees by hand (for testing). CC: mengxr manishamde chouqin codedeft Just notifying you of these small bug fixes. Author: Joseph K. Bradley Closes #2785 from jkbradley/dtrunner-update and squashes the following commits: 9132321 [Joseph K. Bradley] merged with master, fixed imports 9dbd000 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update e116473 [Joseph K. Bradley] Changed Predict class to be public DeveloperApi. f502e65 [Joseph K. Bradley] bug fix for SPARK-3934 7f3d60f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update ba567ab [Joseph K. Bradley] Changed DTRunner to load test data using same number of features as in training data. 4e88c1f [Joseph K. Bradley] changed RF toString to print total number of nodes --- .../examples/mllib/DecisionTreeRunner.scala | 3 ++- .../mllib/tree/impl/DTStatsAggregator.scala | 16 +--------------- .../mllib/tree/impl/DecisionTreeMetadata.scala | 7 ++++++- .../apache/spark/mllib/tree/model/Predict.scala | 5 ++++- .../mllib/tree/model/RandomForestModel.scala | 4 ++-- .../spark/mllib/tree/RandomForestSuite.scala | 16 ++++++++++++++++ 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 837d0591478c5..0890e6263e165 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -189,9 +189,10 @@ object DecisionTreeRunner { // Create training, test sets. val splits = if (params.testInput != "") { // Load testInput. + val numFeatures = examples.take(1)(0).features.size val origTestExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) } params.algo match { case Classification => { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 55f422dff0d71..ce8825cc03229 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator( numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } - /** - * Indicator for each feature of whether that feature is an unordered feature. - * TODO: Is Array[Boolean] any faster? - */ - def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** * Total number of elements stored in this aggregator */ @@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator( * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. */ - def getFeatureOffset(featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - featureOffsets(featureIndex) - } + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) /** * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. */ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") val baseOffset = featureOffsets(featureIndex) (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 212dce25236e0..772c02670e541 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata { +private[tree] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. @@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata { } val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } // We check the number of bins here against maxPossibleBins. // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index d8476b5cd7bc7..004838ee5ba0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,12 +17,15 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.annotation.DeveloperApi + /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -private[tree] class Predict( +@DeveloperApi +class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala index 4d66d6d81caa5..6a22e2abe59bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext */ override def toString: String = algo match { case Classification => - s"RandomForestModel classifier with $numTrees trees" + s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes" case Regression => - s"RandomForestModel regressor with $numTrees trees" + s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes" case _ => throw new IllegalArgumentException( s"RandomForestModel given unknown algo parameter: $algo.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 20d372dc1d3ca..fb44ceb0f57ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -173,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)) + arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) + val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, + featureSubsetStrategy = "sqrt", seed = 12345) + RandomForestSuite.validateClassifier(model, arr, 1.0) + } + } object RandomForestSuite { From f406a8391825d8866110f29a0d656c82cd064520 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 18 Oct 2014 12:33:20 -0700 Subject: [PATCH 312/315] SPARK-3926 [CORE] Result of JavaRDD.collectAsMap() is not Serializable Make JavaPairRDD.collectAsMap result Serializable since Java Maps generally are Author: Sean Owen Closes #2805 from srowen/SPARK-3926 and squashes the following commits: ecb78ee [Sean Owen] Fix conflict between java.io.Serializable and use of Scala's Serializable f4717f9 [Sean Owen] Oops, fix compile problem ae1b36f [Sean Owen] Expand to cover Maps returned from other Java API methods as well 51c26c2 [Sean Owen] Make JavaPairRDD.collectAsMap result Serializable since Java Maps generally are --- .../org/apache/spark/api/java/JavaPairRDD.scala | 12 +++++++----- .../org/apache/spark/api/java/JavaRDDLike.scala | 7 ++++--- .../scala/org/apache/spark/api/java/JavaUtils.scala | 10 ++++++++++ .../scala/org/apache/spark/sql/api/java/Row.scala | 3 ++- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0846225e4f992..c38b96528d037 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} @@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) + mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** * :: Experimental :: @@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** * :: Experimental :: @@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. */ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) + /** * Pass each value in the key-value pair RDD through a map function without changing the keys; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 545bc0e9e99ed..c744399483349 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -30,6 +30,7 @@ import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). @@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * (Experimental) Approximate version of countByValue(). */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) + rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 22810cb1c662d..b52d0a5028e84 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,10 +19,20 @@ package org.apache.spark.api.java import com.google.common.base.Optional +import scala.collection.convert.Wrappers.MapWrapper + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) case None => Optional.absent() } + + // Workaround for SPARK-3926 / SI-8911 + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = + new SerializableMapWrapper(underlying) + + class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) + extends MapWrapper(underlying) with java.io.Serializable + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index e9d04ce7aae4c..df01411f60a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions import scala.math.BigDecimal +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -114,7 +115,7 @@ object Row { // they are actually accessed. case row: ScalaRow => new Row(row) case map: scala.collection.Map[_, _] => - JavaConversions.mapAsJavaMap( + mapAsSerializableJavaMap( map.map { case (key, value) => (toJavaValue(key), toJavaValue(value)) } From 05db2da7dc256822cdb602c4821cbb9fb84dac98 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 18 Oct 2014 19:14:48 -0700 Subject: [PATCH 313/315] [SPARK-3952] [Streaming] [PySpark] add Python examples in Streaming Programming Guide Having Python examples in Streaming Programming Guide. Also add RecoverableNetworkWordCount example. Author: Davies Liu Author: Davies Liu Closes #2808 from davies/pyguide and squashes the following commits: 8d4bec4 [Davies Liu] update readme 26a7e37 [Davies Liu] fix format 3821c4d [Davies Liu] address comments, add missing file 7e4bb8a [Davies Liu] add Python examples in Streaming Programming Guide --- docs/README.md | 3 +- docs/streaming-programming-guide.md | 304 +++++++++++++++++- .../recoverable_network_wordcount.py | 80 +++++ python/docs/pyspark.streaming.rst | 10 + python/pyspark/streaming/dstream.py | 8 +- 5 files changed, 391 insertions(+), 14 deletions(-) create mode 100644 examples/src/main/python/streaming/recoverable_network_wordcount.py create mode 100644 python/docs/pyspark.streaming.rst diff --git a/docs/README.md b/docs/README.md index 0facecdd5f767..d2d58e435d4c4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output varies between versions of Jekyll and its dependencies, we list specific versions here in some cases: - $ sudo gem install jekyll -v 1.4.3 - $ sudo gem uninstall kramdown -v 1.4.1 + $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 738309c668387..8bbba88b31978 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -212,6 +212,67 @@ The complete code can be found in the Spark Streaming example [JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
    + +
    +First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +# Create a local StreamingContext with two working thread and batch interval of 1 second +sc = SparkContext("local[2]", "NetworkWordCount") +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` + +{% highlight python %} +# Create a DStream that will connect to hostname:port, like localhost:9999 +lines = ssc.socketTextStream("localhost", 9999) +{% endhighlight %} + +This `lines` DStream represents the stream of data that will be received from the data +server. Each record in this DStream is a line of text. Next, we want to split the lines by +space into words. + +{% highlight python %} +# Split each line into words +words = lines.flatMap(lambda line: line.split(" ")) +{% endhighlight %} + +`flatMap` is a one-to-many DStream operation that creates a new DStream by +generating multiple new records from each record in the source DStream. In this case, +each line will be split into multiple words and the stream of words is represented as the +`words` DStream. Next, we want to count these words. + +{% highlight python %} +# Count each word in each batch +pairs = words.map(lambda word: (word, 1)) +wordCounts = pairs.reduceByKey(lambda x, y: x + y) + +# Print the first ten elements of each RDD generated in this DStream to the console +wordCounts.pprint() +{% endhighlight %} + +The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, +1)` pairs, which is then reduced to get the frequency of words in each batch of data. +Finally, `wordCounts.pprint()` will print a few of the counts generated every second. + +Note that when these lines are executed, Spark Streaming only sets up the computation it +will perform when it is started, and no real processing has started yet. To start the processing +after all the transformations have been setup, we finally call + +{% highlight python %} +ssc.start() # Start the computation +ssc.awaitTermination() # Wait for the computation to terminate +{% endhighlight %} + +The complete code can be found in the Spark Streaming example +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py). +
    +
    @@ -236,6 +297,11 @@ $ ./bin/run-example streaming.NetworkWordCount localhost 9999 $ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +{% endhighlight %} +
    @@ -259,8 +325,11 @@ hello world
    Property NameDefaultMeaning
    spark.executor.memory512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
    spark.executor.extraJavaOptions (none)spark.ui.port 4040 - Port for your application's dashboard, which shows memory and workload data + Port for your application's dashboard, which shows memory and workload data.
    spark.scheduler.revive.interval 1000 - The interval length for the scheduler to revive the worker resource offers to run tasks. - (in milliseconds) + The interval length for the scheduler to revive the worker resource offers to run tasks + (in milliseconds).
    +
    + +
    {% highlight bash %} -# TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount +# TERMINAL 2: RUNNING NetworkWordCount $ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... @@ -271,6 +340,37 @@ Time: 1357008430000 ms (world,1) ... {% endhighlight %} +
    + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING JavaNetworkWordCount + +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 +... +------------------------------------------- +Time: 1357008430000 ms +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    +{% highlight bash %} +# TERMINAL 2: RUNNING network_wordcount.py + +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +... +------------------------------------------- +Time: 2014-10-14 15:25:21 +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    @@ -398,9 +498,34 @@ JavaSparkContext sc = ... //existing JavaSparkContext JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); {% endhighlight %} +
    + +A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +sc = SparkContext(master, appName) +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. +
    After a context is defined, you have to do the follow steps. + 1. Define the input sources. 1. Setup the streaming computations. 1. Start the receiving and procesing of data using `streamingContext.start()`. @@ -483,6 +608,9 @@ methods for creating DStreams from files and Akka actors as input sources.
    streamingContext.fileStream(dataDirectory);
    +
    + streamingContext.textFileStream(dataDirectory) +
    Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that @@ -684,13 +812,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} + +
    + +{% highlight python %} +def updateFunction(newValues, runningCount): + if runningCount is None: + runningCount = 0 + return sum(newValues, runningCount) # add the new values with the previous running count to get the new count +{% endhighlight %} + +This is applied on a DStream containing words (say, the `pairs` DStream containing `(word, +1)` pairs in the [earlier example](#a-quick-example)). + +{% highlight python %} +runningCounts = pairs.updateStateByKey(updateFunction) +{% endhighlight %} +
    The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Scala code, take a look at the example -[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). +[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). #### Transform Operation {:.no_toc} @@ -732,6 +877,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform( }); {% endhighlight %} + +
    + +{% highlight python %} +spamInfoRDD = sc.pickleFile(...) # RDD containing spam information + +# join data stream with spam information to do data cleaning +cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...)) +{% endhighlight %}
    @@ -793,6 +947,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000)); {% endhighlight %} + +
    + +{% highlight python %} +# Reduce last 30 seconds of data, every 10 seconds +windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10) +{% endhighlight %} +
    @@ -860,6 +1022,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) *** @@ -872,9 +1035,12 @@ Currently, the following output operations are defined: - + + This is useful for development and debugging. +
    + PS: called pprint() in Python) + @@ -915,17 +1081,41 @@ For this purpose, a developer may inadvertantly try creating a connection object the Spark driver, but try to use it in a Spark worker to save records in the RDDs. For example (in Scala), +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { val connection = createNewConnection() // executed at the driver rdd.foreach(record => { connection.send(record) // executed at the worker }) }) +{% endhighlight %} + +
    +
    + +{% highlight python %} +def sendRecord(rdd): + connection = createNewConnection() # executed at the driver + rdd.foreach(lambda record: connection.send(record)) + connection.close() + +dstream.foreachRDD(sendRecord) +{% endhighlight %} + +
    +
    - This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. + This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. - However, this can lead to another common mistake - creating a new connection for every record. For example, +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreach(record => { val connection = createNewConnection() @@ -933,9 +1123,28 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} - Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. +
    +
    + +{% highlight python %} +def sendRecord(record): + connection = createNewConnection() + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) +{% endhighlight %} +
    +
    + + Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. + +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { val connection = createNewConnection() @@ -943,13 +1152,31 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} +
    + +
    +{% highlight python %} +def sendPartition(iter): + connection = createNewConnection() + for record in iter: + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
    +
    - This amortizes the connection creation overheads over many records. + This amortizes the connection creation overheads over many records. - Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. One can maintain a static pool of connection objects than can be reused as RDDs of multiple batches are pushed to the external system, thus further reducing the overheads. - + +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { // ConnectionPool is a static, lazily initialized pool of connections @@ -958,8 +1185,25 @@ For example (in Scala), ConnectionPool.returnConnection(connection) // return to the pool for future reuse }) }) +{% endhighlight %} +
    - Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. +
    +{% highlight python %} +def sendPartition(iter): + # ConnectionPool is a static, lazily initialized pool of connections + connection = ConnectionPool.getConnection() + for record in iter: + connection.send(record) + # return to the pool for future reuse + ConnectionPool.returnConnection(connection) + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
    +
    + +Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. ##### Other points to remember: @@ -1376,6 +1620,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data the computation by using `new JavaStreamingContext(checkpointDirectory)`. +
    + +This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows. + +{% highlight python %} +# Function to create and setup a new StreamingContext +def functionToCreateContext(): + sc = SparkContext(...) # new context + ssc = new StreamingContext(...) + lines = ssc.socketTextStream(...) # create DStreams + ... + ssc.checkpoint(checkpointDirectory) # set checkpoint directory + return ssc + +# Get StreamingContext from checkpoint data or create a new one +context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext) + +# Do additional setup on context that needs to be done, +# irrespective of whether it is being started or restarted +context. ... + +# Start the context +context.start() +context.awaitTermination() +{% endhighlight %} + +If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. +If the directory does not exist (i.e., running for the first time), +then the function `functionToCreateContext` will be called to create a new +context and set up the DStreams. See the Python example +[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). +This example appends the word counts of network data into a file. + +You can also explicitly create a `StreamingContext` from the checkpoint data and start the + computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`. + +
    + **Note**: If Spark Streaming and/or the Spark Streaming program is recompiled, @@ -1572,7 +1854,11 @@ package and renamed for better clarity. [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) + - Python docs + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) + * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) + and [Python] ({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming) * [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming. diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py new file mode 100644 index 0000000000000..fc6827c82bf9b --- /dev/null +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in text encoded with UTF8 received from the network every second. + + Usage: recoverable_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive + data. directory to HDFS-compatible file system which checkpoint data + file to which the word counts will be appended + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \ + localhost 9999 ~/checkpoint/ ~/out` + + If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + the checkpoint data. +""" + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + + +def createContext(host, port, outputPath): + # If you do not see this printed, that means the StreamingContext has been loaded + # from the new checkpoint + print "Creating new context" + if os.path.exists(outputPath): + os.remove(outputPath) + sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. generated by 'nc') + lines = ssc.socketTextStream(host, port) + words = lines.flatMap(lambda line: line.split(" ")) + wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + + def echo(time, rdd): + counts = "Counts at time %s %s" % (time, rdd.collect()) + print counts + print "Appending to " + os.path.abspath(outputPath) + with open(outputPath, 'a') as f: + f.write(counts + "\n") + + wordCounts.foreachRDD(echo) + return ssc + +if __name__ == "__main__": + if len(sys.argv) != 5: + print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ + " " + exit(-1) + host, port, checkpoint, output = sys.argv[1:] + ssc = StreamingContext.getOrCreate(checkpoint, + lambda: createContext(host, int(port), output)) + ssc.start() + ssc.awaitTermination() diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst new file mode 100644 index 0000000000000..5024d694b668f --- /dev/null +++ b/python/docs/pyspark.streaming.rst @@ -0,0 +1,10 @@ +pyspark.streaming module +================== + +Module contents +--------------- + +.. automodule:: pyspark.streaming + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 5ae5cf07f0137..0826ddc56e844 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -441,9 +441,11 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio if `invReduceFunc` is not None, the reduction is done incrementally using the old window's reduced value : - 1. reduce the new values that entered the window (e.g., adding new counts) - 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - This is more efficient than `invReduceFunc` is None. + + 1. reduce the new values that entered the window (e.g., adding new counts) + + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. @param reduceFunc: associative reduce function @param invReduceFunc: inverse reduce function of `reduceFunc` From 7e63bb49c526c3f872619ae14e4b5273f4c535e9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 19 Oct 2014 00:31:06 -0700 Subject: [PATCH 314/315] [SPARK-2546] Clone JobConf for each task (branch-1.0 / 1.1 backport) This patch attempts to fix SPARK-2546 in `branch-1.0` and `branch-1.1`. The underlying problem is that thread-safety issues in Hadoop Configuration objects may cause Spark tasks to get stuck in infinite loops. The approach taken here is to clone a new copy of the JobConf for each task rather than sharing a single copy between tasks. Note that there are still Configuration thread-safety issues that may affect the driver, but these seem much less likely to occur in practice and will be more complex to fix (see discussion on the SPARK-2546 ticket). This cloning is guarded by a new configuration option (`spark.hadoop.cloneConf`) and is disabled by default in order to avoid unexpected performance regressions for workloads that are unaffected by the Configuration thread-safety issues. Author: Josh Rosen Closes #2684 from JoshRosen/jobconf-fix-backport and squashes the following commits: f14f259 [Josh Rosen] Add configuration option to control cloning of Hadoop JobConf. b562451 [Josh Rosen] Remove unused jobConfCacheKey field. dd25697 [Josh Rosen] [SPARK-2546] [1.0 / 1.1 backport] Clone JobConf for each task. (cherry picked from commit 2cd40db2b3ab5ddcb323fd05c171dbd9025f9e71) Signed-off-by: Josh Rosen Conflicts: core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala --- .../org/apache/spark/rdd/HadoopRDD.scala | 53 +++++++++++++------ docs/configuration.md | 9 ++++ 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8010dd90082f8..775141775e06c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -132,27 +132,47 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() + private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value - if (conf.isInstanceOf[JobConf]) { - // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs + // somewhat rarely because most jobs treat the configuration as though it's immutable. One + // solution, implemented here, is to clone the Configuration object. Unfortunately, this + // clone can be very expensive. To avoid unexpected performance regressions for workloads and + // Hadoop versions that do not suffer from these thread-safety issues, this cloning is + // disabled by default. HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + if (!conf.isInstanceOf[JobConf]) { + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + } newJobConf } + } else { + if (conf.isInstanceOf[JobConf]) { + logDebug("Re-using user-broadcasted JobConf") + conf.asInstanceOf[JobConf] + } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { + logDebug("Re-using cached JobConf") + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + } else { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the + // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. + // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } + } } } @@ -276,7 +296,10 @@ class HadoopRDD[K, V]( } private[spark] object HadoopRDD extends Logging { - /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration(). + */ val CONFIGURATION_INSTANTIATION_LOCK = new Object() /** diff --git a/docs/configuration.md b/docs/configuration.md index f0204c640bc89..96fa1377ec399 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -619,6 +619,15 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + + + + From d1966f3a8bafdcef87d10ef9db5976cf89faee4b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 19 Oct 2014 20:02:31 -0700 Subject: [PATCH 315/315] [SPARK-3902] [SPARK-3590] Stabilize AsynRDDActions and add Java API This PR adds a Java API for AsyncRDDActions and promotes the API from `Experimental` to stable. Author: Josh Rosen Author: Josh Rosen Closes #2760 from JoshRosen/async-rdd-actions-in-java and squashes the following commits: 0d45fbc [Josh Rosen] Whitespace fix. ad3ae53 [Josh Rosen] Merge remote-tracking branch 'origin/master' into async-rdd-actions-in-java c0153a5 [Josh Rosen] Remove unused variable. e8e2867 [Josh Rosen] Updates based on Marcelo's review feedback 7a1417f [Josh Rosen] Removed unnecessary java.util import. 6f8f6ac [Josh Rosen] Fix import ordering. ff28e49 [Josh Rosen] Add MiMa excludes and fix a scalastyle error. 346e46e [Josh Rosen] [SPARK-3902] Stabilize AsyncRDDActions; add Java API. --- .../spark/api/java/JavaFutureAction.java | 33 +++++++ .../scala/org/apache/spark/FutureAction.scala | 86 ++++++++++++++--- .../apache/spark/api/java/JavaRDDLike.scala | 53 ++++++++--- .../apache/spark/rdd/AsyncRDDActions.scala | 3 - .../java/org/apache/spark/JavaAPISuite.java | 93 ++++++++++++++++++- project/MimaExcludes.scala | 13 ++- 6 files changed, 246 insertions(+), 35 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java new file mode 100644 index 0000000000000..0ad189633e427 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + + +import java.util.List; +import java.util.concurrent.Future; + +public interface JavaFutureAction extends Future { + + /** + * Returns the job IDs run by the underlying async operation. + * + * This returns the current snapshot of the job list. Certain operations may run multiple + * jobs, so multiple calls to this method may return different lists. + */ + List jobIds(); +} diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e8f761eaa5799..d5c8f9d76c476 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -17,20 +17,21 @@ package org.apache.spark -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.Try +import java.util.Collections +import java.util.concurrent.TimeUnit -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.{Failure, Try} + /** - * :: Experimental :: * A future for the result of an action to support cancellation. This is an extension of the * Scala Future interface to support cancellation. */ -@Experimental trait FutureAction[T] extends Future[T] { // Note that we redefine methods of the Future trait here explicitly so we can specify a different // documentation (with reference to the word "action"). @@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] { */ override def isCompleted: Boolean + /** + * Returns whether the action has been cancelled. + */ + def isCancelled: Boolean + /** * The value of this Future. * @@ -96,15 +102,16 @@ trait FutureAction[T] extends Future[T] { /** - * :: Experimental :: * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ -@Experimental class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { + @volatile private var _cancelled: Boolean = false + override def cancel() { + _cancelled = true jobWaiter.cancel() } @@ -143,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished + + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { if (jobWaiter.jobFinished) { @@ -164,12 +173,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: /** - * :: Experimental :: * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the * action thread if it is being blocked by a job. */ -@Experimental class ComplexFutureAction[T] extends FutureAction[T] { // Pointer to the thread that is executing the action. It is set when the action is run. @@ -222,7 +229,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { - if (!cancelled) { + if (!isCancelled) { rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) } else { throw new SparkException("Action has been cancelled") @@ -243,10 +250,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } - /** - * Returns whether the promise has been cancelled. - */ - def cancelled: Boolean = _cancelled + override def isCancelled: Boolean = _cancelled @throws(classOf[InterruptedException]) @throws(classOf[scala.concurrent.TimeoutException]) @@ -271,3 +275,55 @@ class ComplexFutureAction[T] extends FutureAction[T] { def jobIds = jobs } + +private[spark] +class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) + extends JavaFutureAction[T] { + + import scala.collection.JavaConverters._ + + override def isCancelled: Boolean = futureAction.isCancelled + + override def isDone: Boolean = { + // According to java.util.Future's Javadoc, this returns True if the task was completed, + // whether that completion was due to successful execution, an exception, or a cancellation. + futureAction.isCancelled || futureAction.isCompleted + } + + override def jobIds(): java.util.List[java.lang.Integer] = { + Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + } + + private def getImpl(timeout: Duration): T = { + // This will throw TimeoutException on timeout: + Await.ready(futureAction, timeout) + futureAction.value.get match { + case scala.util.Success(value) => converter(value) + case Failure(exception) => + if (isCancelled) { + throw new CancellationException("Job cancelled").initCause(exception) + } else { + // java.util.Future.get() wraps exceptions in ExecutionException + throw new ExecutionException("Exception thrown by job", exception) + } + } + } + + override def get(): T = getImpl(Duration.Inf) + + override def get(timeout: Long, unit: TimeUnit): T = + getImpl(Duration.fromNanos(unit.toNanos(timeout))) + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized { + if (isDone) { + // According to java.util.Future's Javadoc, this should return false if the task is completed. + false + } else { + // We're limited in terms of the semantics we can provide here; our cancellation is + // asynchronous and doesn't provide a mechanism to not cancel if the job is running. + futureAction.cancel() + true + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c744399483349..efb8978f7ce12 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,12 +21,14 @@ import java.util.{Comparator, List => JList, Iterator => JIterator} import java.lang.{Iterable => JIterable, Long => JLong} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag @@ -294,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to all elements of this RDD. */ def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean((x: T) => f.call(x)) - rdd.foreach(cleanF) + rdd.foreach(x => f.call(x)) } /** @@ -576,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name /** - * :: Experimental :: - * The asynchronous version of the foreach action. - * - * @param f the function to apply to all the elements of the RDD - * @return a FutureAction for the action + * The asynchronous version of `count`, which returns a + * future for counting the number of elements in this RDD. */ - @Experimental - def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = { - import org.apache.spark.SparkContext._ - rdd.foreachAsync(x => f.call(x)) + def countAsync(): JavaFutureAction[JLong] = { + new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + } + + /** + * The asynchronous version of `collect`, which returns a future for + * retrieving an array containing all of the elements in this RDD. + */ + def collectAsync(): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) + } + + /** + * The asynchronous version of the `take` action, which returns a + * future for retrieving the first `num` elements of this RDD. + */ + def takeAsync(num: Int): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) } + /** + * The asynchronous version of the `foreach` action, which + * applies a function f to all the elements of this RDD. + */ + def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } + + /** + * The asynchronous version of the `foreachPartition` action, which + * applies a function f to each partition of this RDD. + */ + def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ede5568493cc0..9f9f10b7ebc3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} -import org.apache.spark.annotation.Experimental /** - * :: Experimental :: * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -@Experimental class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8fa822ae4bd8..3190148fb5f43 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,6 +20,7 @@ import java.io.*; import java.net.URI; import java.util.*; +import java.util.concurrent.*; import scala.Tuple2; import scala.Tuple3; @@ -29,6 +30,7 @@ import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.base.Throwables; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; @@ -43,10 +45,7 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.partial.BoundedDouble; @@ -1308,6 +1307,92 @@ public void collectUnderlyingScalaRDD() { Assert.assertEquals(data.size(), collected.length); } + private static final class BuggyMapFunction implements Function { + + @Override + public T call(T x) throws Exception { + throw new IllegalStateException("Custom exception!"); + } + } + + @Test + public void collectAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction> future = rdd.collectAsync(); + List result = future.get(); + Assert.assertEquals(data, result); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void foreachAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.foreachAsync( + new VoidFunction() { + @Override + public void call(Integer integer) throws Exception { + // intentionally left blank. + } + } + ); + future.get(); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void countAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.countAsync(); + long count = future.get(); + Assert.assertEquals(data.size(), count); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void testAsyncActionCancellation() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { + @Override + public void call(Integer integer) throws Exception { + Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. + } + }); + future.cancel(true); + Assert.assertTrue(future.isCancelled()); + Assert.assertTrue(future.isDone()); + try { + future.get(2000, TimeUnit.MILLISECONDS); + Assert.fail("Expected future.get() for cancelled job to throw CancellationException"); + } catch (CancellationException ignored) { + // pass + } + } + + @Test + public void testAsyncActionErrorWrapping() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync(); + try { + future.get(2, TimeUnit.SECONDS); + Assert.fail("Expected future.get() for failed job to throw ExcecutionException"); + } catch (ExecutionException ee) { + Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); + } + Assert.assertTrue(future.isDone()); + } + + /** * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, * since that's the only artifact where Guava classes have been relocated. diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 350aad47735e4..c58666af84f24 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -54,7 +54,18 @@ object MimaExcludes { // TaskContext was promoted to Abstract class ProblemFilters.exclude[AbstractClassProblem]( "org.apache.spark.TaskContext") - + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") ) case v if v.startsWith("1.1") =>
    Output OperationMeaning
    print() print() Prints first ten elements of every batch of data in a DStream on the driver. - This is useful for development and debugging.
    saveAsObjectFiles(prefix, [suffix])
    spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues.
    spark.executor.heartbeatInterval 10000