diff --git a/.gitignore b/.gitignore index 07524bc429e92..8ecf536e79a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,6 @@ dev/create-release/*final spark-*-bin-*.tgz unit-tests.log /lib/ -ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index beacc39500aaa..34be7f0ebd752 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -130,6 +130,7 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "hash", "cume_dist", "date_add", "date_format", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index df36bc869acb4..9bb7876b384ce 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -340,6 +340,26 @@ setMethod("crc32", column(jc) }) +#' hash +#' +#' Calculates the hash code of given columns, and returns the result as a int column. +#' +#' @rdname hash +#' @name hash +#' @family misc_funcs +#' @export +#' @examples \dontrun{hash(df$c)} +setMethod("hash", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "hash", jcols) + column(jc) + }) + #' dayofmonth #' #' Extracts the day of the month as an integer from a given date/timestamp/string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ba6861709754d..5ba68e3a4f378 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -736,6 +736,10 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname hash +#' @export +setGeneric("hash", function(x, ...) { standardGeneric("hash") }) + #' @rdname cume_dist #' @export setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index eaf60beda3473..97625b94a0e23 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -922,7 +922,7 @@ test_that("column functions", { c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) - c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 94719a4572ef6..7de9df1e489fb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi This implementation is non-blocking, asynchronously handling the results of each job and triggering the next job using callbacks on futures. */ - def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] = if (results.size >= num || partsScanned >= totalParts) { Future.successful(results.toSeq) } else { @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val buf = new Array[Array[T]](p.size) self.context.setCallSite(callSite) @@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - job.flatMap {_ => + job.flatMap { _ => buf.foreach(results ++= _.take(num - results.size)) continue(partsScanned + p.size) } } - new ComplexFutureAction[Seq[T]](continue(0L)(_)) + new ComplexFutureAction[Seq[T]](continue(0)(_)) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e25657cc109be..de7102f5b6245 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag]( } else { val buf = new ArrayBuffer[T] val totalParts = this.partitions.length - var partsScanned = 0L + var partsScanned = 0 while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. @@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 24acbed4d7258..ef2ed445005d3 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.take(501) === (1 to 501).toArray) assert(nums.take(999) === (1 to 999).toArray) assert(nums.take(1000) === (1 to 999).toArray) + + nums = sc.parallelize(1 to 2, 2) + assert(nums.take(2147483638).size === 2) + assert(nums.takeAsync(2147483638).get.size === 2) } test("top with predefined ordering") { diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b0a3374becc6a..d404939d1caee 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -64,9 +64,6 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG -# TODO: It would be nice to do some verifications here -# i.e. check whether ec2 scripts have the new version - # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 7f152b7f53559..5d0ac16b3b0a1 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -159,7 +159,6 @@ def get_commits(tag): "build": CORE_COMPONENT, "deploy": CORE_COMPONENT, "documentation": CORE_COMPONENT, - "ec2": "EC2", "examples": CORE_COMPONENT, "graphx": "GraphX", "input/output": CORE_COMPONENT, diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index e4373f79f7922..cd3ff293502ae 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -84,13 +84,13 @@ hadoop-yarn-server-web-proxy-2.2.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 7478181406d07..0985089ccea61 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.3.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index faffb8bf398a5..50f062601c02b 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.4.0.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e703c7acd3876..2b6ca983ad65e 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -85,13 +85,13 @@ htrace-core-3.0.4.jar httpclient-4.3.2.jar httpcore-4.3.2.jar ivy-2.4.0.jar -jackson-annotations-2.4.4.jar -jackson-core-2.4.4.jar +jackson-annotations-2.5.3.jar +jackson-core-2.5.3.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.4.4.jar +jackson-databind-2.5.3.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.10-2.4.4.jar +jackson-module-scala_2.10-2.5.3.jar jackson-xc-1.9.13.jar janino-2.7.8.jar jansi-1.4.jar diff --git a/dev/lint-python b/dev/lint-python index 0b97213ae3dff..1765a07d2f22b 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 47cd600bd18a4..1fc6596164124 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -406,15 +406,6 @@ def contains_file(self, filename): should_run_build_tests=True ) -ec2 = Module( - name="ec2", - dependencies=[], - source_file_regexes=[ - "ec2/", - ] -) - - yarn = Module( name="yarn", dependencies=[], diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 424ce6ad7663c..def87aa4087e3 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -70,19 +70,10 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de # Generate manifests for each Hadoop profile: for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do echo "Performing Maven install for $HADOOP_PROFILE" - $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar install:install -q \ - -pl '!assembly' \ - -pl '!examples' \ - -pl '!external/flume-assembly' \ - -pl '!external/kafka-assembly' \ - -pl '!external/twitter' \ - -pl '!external/flume' \ - -pl '!external/mqtt' \ - -pl '!external/mqtt-assembly' \ - -pl '!external/zeromq' \ - -pl '!external/kafka' \ - -pl '!tags' \ - -DskipTests + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install -q + + echo "Performing Maven validate for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q echo "Generating dependency manifest for $HADOOP_PROFILE" mkdir -p dev/pr-deps diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 62d75eff71057..d493f62f0e578 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -98,8 +98,6 @@
  • Spark Standalone
  • Mesos
  • YARN
  • -
  • -
  • Amazon EC2
  • diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index faaf154d243f5..2810112f5294e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -53,8 +53,6 @@ The system currently supports three cluster managers: and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone -cluster on Amazon EC2. # Submitting Applications diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 7f60f82b966fe..0000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,192 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 ---- - -The `spark-ec2` script, located in Spark's `ec2` directory, allows you -to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark and HDFS on the cluster for you. This guide describes -how to use `spark-ec2` to launch clusters, how to run jobs on them, and how -to shut them down. It assumes you've already signed up for an EC2 account -on the [Amazon Web Services site](http://aws.amazon.com/). - -`spark-ec2` is designed to manage multiple named clusters. You can -launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named -`test` will contain a master node in a security group called -`test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. - - -# Before You Start - -- Create an Amazon EC2 key pair for yourself. This can be done by - logging into your Amazon Web Services account through the [AWS - console](http://aws.amazon.com/console/), clicking Key Pairs on the - left sidebar, and creating and downloading a key. Make sure that you - set the permissions for the private key file to `600` (i.e. only you - can read and write it) so that `ssh` will work. -- Whenever you want to use the `spark-ec2` script, set the environment - variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to your - Amazon EC2 access key ID and secret access key. These can be - obtained from the [AWS homepage](http://aws.amazon.com/) by clicking - Account \> Security Credentials \> Access Credentials. - -# Launching a Cluster - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run - `./spark-ec2 -k -i -s launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster - ``` - -- After everything launches, check that the cluster scheduler is up and sees - all the slaves by going to its web UI, which will be printed at the end of - the script (typically `http://:8080`). - -You can also run `./spark-ec2 --help` to see more usage options. The -following options are worth pointing out: - -- `--instance-type=` can be used to specify an EC2 -instance type to use. For now, the script only supports 64-bit instance -types, and the default type is `m1.large` (which has 2 cores and 7.5 GB -RAM). Refer to the Amazon pages about [EC2 instance -types](http://aws.amazon.com/ec2/instance-types) and [EC2 -pricing](http://aws.amazon.com/ec2/#pricing) for information about other -instance types. -- `--region=` specifies an EC2 region in which to launch -instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone -to launch instances in. Sometimes, you will get an error because there -is not enough capacity in one zone, and you should try to launch in -another. -- `--ebs-vol-size=` will attach an EBS volume with a given amount - of space to each node so that you can have a persistent HDFS cluster - on your nodes across cluster restarts (see below). -- `--spot-price=` will launch the worker nodes as - [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), - bidding for the given maximum price (in dollars). -- `--spark-version=` will pre-load the cluster with the - specified version of Spark. The `` can be a version number - (e.g. "0.7.3") or a specific git hash. By default, a recent - version will be used. -- `--spark-git-repo=` will let you run a custom version of - Spark that is built from the given git repository. By default, the - [Apache Github mirror](https://github.com/apache/spark) will be used. - When using a custom Spark version, `--spark-version` must be set to git - commit hash, such as 317e114, instead of a version number. -- If one of your launches fails due to e.g. not having the right -permissions on your private key file, you can run `launch` with the -`--resume` option to restart the setup process on an existing cluster. - -# Launching a Cluster in a VPC - -- Run - `./spark-ec2 -k -i -s --vpc-id= --subnet-id= launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), `` is the name of your VPC, `` is the - name of your subnet, and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --vpc-id=vpc-a28d24c7 --subnet-id=subnet-4eb27b39 --spark-version=1.1.0 launch my-spark-cluster - ``` - -# Running Applications - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 -k -i login ` to - SSH into the cluster, where `` and `` are as - above. (This is just for convenience; you could also use - the EC2 console.) -- To deploy code or data within your cluster, you can log in and use the - provided script `~/spark-ec2/copy-dir`, which, - given a directory path, RSYNCs it to the same location on all the slaves. -- If your application needs to access large datasets, the fastest way to do - that is to load them from Amazon S3 or an Amazon EBS device into an - instance of the Hadoop Distributed File System (HDFS) on your nodes. - The `spark-ec2` script already sets up a HDFS instance for you. It's - installed in `/root/ephemeral-hdfs`, and can be accessed using the - `bin/hadoop` script in that directory. Note that the data in this - HDFS goes away when you stop and restart a machine. -- There is also a *persistent HDFS* instance in - `/root/persistent-hdfs` that will keep data across cluster restarts. - Typically each node has relatively little space of persistent data - (about 3 GB), but you can use the `--ebs-vol-size` option to - `spark-ec2` to attach a persistent EBS volume to each node for - storing the persistent HDFS. -- Finally, if you get errors while running your application, look at the slave's logs - for that application inside of the scheduler work directory (/root/spark/work). You can - also view the status of the cluster using the web UI: `http://:8080`. - -# Configuration - -You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such -as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to -do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, -then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers. - -The [configuration guide](configuration.html) describes the available configuration options. - -# Terminating a Cluster - -***Note that there is no way to recover data on EC2 nodes after shutting -them down! Make sure you have copied everything important off the nodes -before stopping them.*** - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 destroy `. - -# Pausing and Restarting Clusters - -The `spark-ec2` script also supports pausing a cluster. In this case, -the VMs are stopped but not terminated, so they -***lose all data on ephemeral disks*** but keep the data in their -root partitions and their `persistent-hdfs`. Stopped machines will not -cost you any EC2 cycles, but ***will*** continue to cost money for EBS -storage. - -- To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 --region= stop `. -- To restart it later, run -`./spark-ec2 -i --region= start `. -- To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 --region= destroy ` as described in the previous -section. - -# Limitations - -- Support for "cluster compute" nodes is limited -- there's no way to specify a - locality group. However, you can launch slave nodes in your - `-slaves` group manually and then use `spark-ec2 launch - --resume` to start a cluster with them. - -If you have a patch or suggestion for one of these limitations, feel free to -[contribute](contributing-to-spark.html) it! - -# Accessing Data in S3 - -Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). - -In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory. diff --git a/docs/index.md b/docs/index.md index ae26f97c86c21..9dfc52a2bdc9b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,7 +64,7 @@ To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] Example applications are also provided in R. For example, - + ./bin/spark-submit examples/src/main/r/dataframe.R # Launching on a Cluster @@ -73,7 +73,6 @@ The Spark [cluster mode overview](cluster-overview.html) explains the key concep Spark can run both by itself, or over several existing cluster managers. It currently provides several options for deployment: -* [Amazon EC2](ec2-scripts.html): our EC2 scripts let you launch a cluster in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): simplest way to deploy Spark on a private cluster * [Apache Mesos](running-on-mesos.html) * [Hadoop YARN](running-on-yarn.html) @@ -103,7 +102,7 @@ options for deployment: * [Cluster Overview](cluster-overview.html): overview of concepts and components when running on a cluster * [Submitting Applications](submitting-applications.html): packaging and deploying applications * Deployment modes: - * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes + * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) diff --git a/ec2/README b/ec2/README deleted file mode 100644 index 72434f24bf98d..0000000000000 --- a/ec2/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder contains a script, spark-ec2, for launching Spark clusters on -Amazon EC2. Usage instructions are available online at: - -http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh deleted file mode 100644 index 4f3e8da809f7f..0000000000000 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# These variables are automatically filled in by the spark-ec2 script. -export MASTERS="{{master_list}}" -export SLAVES="{{slave_list}}" -export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" -export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" -export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" -export MODULES="{{modules}}" -export SPARK_VERSION="{{spark_version}}" -export TACHYON_VERSION="{{tachyon_version}}" -export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" -export SWAP_MB="{{swap}}" -export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" -export SPARK_MASTER_OPTS="{{spark_master_opts}}" -export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 deleted file mode 100755 index 26e7d22655694..0000000000000 --- a/ec2/spark-ec2 +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Preserve the user's CWD so that relative paths are passed correctly to -#+ the underlying Python script. -SPARK_EC2_DIR="$(dirname "$0")" - -python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py deleted file mode 100755 index 19d5980560fef..0000000000000 --- a/ec2/spark_ec2.py +++ /dev/null @@ -1,1530 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import division, print_function, with_statement - -import codecs -import hashlib -import itertools -import logging -import os -import os.path -import pipes -import random -import shutil -import string -from stat import S_IRUSR -import subprocess -import sys -import tarfile -import tempfile -import textwrap -import time -import warnings -from datetime import datetime -from optparse import OptionParser -from sys import stderr - -if sys.version < "3": - from urllib2 import urlopen, Request, HTTPError -else: - from urllib.request import urlopen, Request - from urllib.error import HTTPError - raw_input = input - xrange = range - -SPARK_EC2_VERSION = "1.6.0" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) - -VALID_SPARK_VERSIONS = set([ - "0.7.3", - "0.8.0", - "0.8.1", - "0.9.0", - "0.9.1", - "0.9.2", - "1.0.0", - "1.0.1", - "1.0.2", - "1.1.0", - "1.1.1", - "1.2.0", - "1.2.1", - "1.3.0", - "1.3.1", - "1.4.0", - "1.4.1", - "1.5.0", - "1.5.1", - "1.5.2", - "1.6.0", -]) - -SPARK_TACHYON_MAP = { - "1.0.0": "0.4.1", - "1.0.1": "0.4.1", - "1.0.2": "0.4.1", - "1.1.0": "0.5.0", - "1.1.1": "0.5.0", - "1.2.0": "0.5.0", - "1.2.1": "0.5.0", - "1.3.0": "0.5.0", - "1.3.1": "0.5.0", - "1.4.0": "0.6.4", - "1.4.1": "0.6.4", - "1.5.0": "0.7.1", - "1.5.1": "0.7.1", - "1.5.2": "0.7.1", - "1.6.0": "0.8.2", -} - -DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION -DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" - -# Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" - - -def setup_external_libs(libs): - """ - Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. - """ - PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" - SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") - - if not os.path.exists(SPARK_EC2_LIB_DIR): - print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( - path=SPARK_EC2_LIB_DIR - )) - print("This should be a one-time operation.") - os.mkdir(SPARK_EC2_LIB_DIR) - - for lib in libs: - versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) - lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) - - if not os.path.isdir(lib_dir): - tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print(" - Downloading {lib}...".format(lib=lib["name"])) - download_stream = urlopen( - "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( - prefix=PYPI_URL_PREFIX, - first_letter=lib["name"][:1], - lib_name=lib["name"], - lib_version=lib["version"] - ) - ) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path, "rb") as tar: - if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=SPARK_EC2_LIB_DIR) - tar.close() - os.remove(tgz_file_path) - print(" - Finished downloading {lib}.".format(lib=lib["name"])) - sys.path.insert(1, lib_dir) - - -# Only PyPI libraries are supported. -external_libs = [ - { - "name": "boto", - "version": "2.34.0", - "md5": "5556223d2d0cc4d06dd4829e671dcecd" - } -] - -setup_external_libs(external_libs) - -import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType -from boto import ec2 - - -class UsageError(Exception): - pass - - -# Configure and parse our command-line arguments -def parse_args(): - parser = OptionParser( - prog="spark-ec2", - version="%prog {v}".format(v=SPARK_EC2_VERSION), - usage="%prog [options] \n\n" - + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") - - parser.add_option( - "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: %default)") - parser.add_option( - "-w", "--wait", type="int", - help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") - parser.add_option( - "-k", "--key-pair", - help="Key pair to use on instances") - parser.add_option( - "-i", "--identity-file", - help="SSH private key file to use for logging into instances") - parser.add_option( - "-p", "--profile", default=None, - help="If you have multiple profiles (AWS or boto config), you can configure " + - "additional, named profiles by using this option (default: %default)") - parser.add_option( - "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: %default). " + - "WARNING: must be 64-bit; small instances won't work") - parser.add_option( - "-m", "--master-instance-type", default="", - help="Master instance type (leave empty for same as instance-type)") - parser.add_option( - "-r", "--region", default="us-east-1", - help="EC2 region used to launch instances in, or to find them in (default: %default)") - parser.add_option( - "-z", "--zone", default="", - help="Availability zone to launch instances in, or 'all' to spread " + - "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies) (default: a single zone chosen at random)") - parser.add_option( - "-a", "--ami", - help="Amazon Machine Image ID to use") - parser.add_option( - "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, - help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") - parser.add_option( - "--spark-git-repo", - default=DEFAULT_SPARK_GITHUB_REPO, - help="Github repo from which to checkout supplied commit hash (default: %default)") - parser.add_option( - "--spark-ec2-git-repo", - default=DEFAULT_SPARK_EC2_GITHUB_REPO, - help="Github repo from which to checkout spark-ec2 (default: %default)") - parser.add_option( - "--spark-ec2-git-branch", - default=DEFAULT_SPARK_EC2_BRANCH, - help="Github repo branch of spark-ec2 to use (default: %default)") - parser.add_option( - "--deploy-root-dir", - default=None, - help="A directory to copy into / on the first master. " + - "Must be absolute. Note that a trailing slash is handled as per rsync: " + - "If you omit it, the last directory of the --deploy-root-dir path will be created " + - "in / before copying its contents. If you append the trailing slash, " + - "the directory is not created and its contents are copied directly into /. " + - "(default: %default).") - parser.add_option( - "--hadoop-major-version", default="1", - help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + - "(Hadoop 2.4.0) (default: %default)") - parser.add_option( - "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", - help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + - "the given local address (for use with login)") - parser.add_option( - "--resume", action="store_true", default=False, - help="Resume installation on a previously launched cluster " + - "(for debugging)") - parser.add_option( - "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Size (in GB) of each EBS volume.") - parser.add_option( - "--ebs-vol-type", default="standard", - help="EBS volume type (e.g. 'gp2', 'standard').") - parser.add_option( - "--ebs-vol-num", type="int", default=1, - help="Number of EBS volumes to attach to each node as /vol[x]. " + - "The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0. " + - "Only support up to 8 EBS volumes.") - parser.add_option( - "--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") - parser.add_option( - "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: %default)") - parser.add_option( - "--spot-price", metavar="PRICE", type="float", - help="If specified, launch slaves as spot instances with the given " + - "maximum price (in dollars)") - parser.add_option( - "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + - "the Ganglia page will be publicly accessible") - parser.add_option( - "--no-ganglia", action="store_false", dest="ganglia", - help="Disable Ganglia monitoring for the cluster") - parser.add_option( - "-u", "--user", default="root", - help="The SSH user you want to connect as (default: %default)") - parser.add_option( - "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") - parser.add_option( - "--use-existing-master", action="store_true", default=False, - help="Launch fresh slaves, but use an existing stopped master if possible") - parser.add_option( - "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " + - "is used as Hadoop major version (default: %default)") - parser.add_option( - "--master-opts", type="string", default="", - help="Extra options to give to master through SPARK_MASTER_OPTS variable " + - "(e.g -Dspark.worker.timeout=180)") - parser.add_option( - "--user-data", type="string", default="", - help="Path to a user-data file (most AMIs interpret this as an initialization script)") - parser.add_option( - "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: %default)") - parser.add_option( - "--additional-security-group", type="string", default="", - help="Additional security group to place the machines in") - parser.add_option( - "--additional-tags", type="string", default="", - help="Additional tags to set on the machines; tags are comma-separated, while name and " + - "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") - parser.add_option( - "--copy-aws-credentials", action="store_true", default=False, - help="Add AWS credentials to hadoop configuration to allow Spark to access S3") - parser.add_option( - "--subnet-id", default=None, - help="VPC subnet to launch instances in") - parser.add_option( - "--vpc-id", default=None, - help="VPC to launch instances in") - parser.add_option( - "--private-ips", action="store_true", default=False, - help="Use private IPs for instances rather than public if VPC/subnet " + - "requires that.") - parser.add_option( - "--instance-initiated-shutdown-behavior", default="stop", - choices=["stop", "terminate"], - help="Whether instances should terminate when shut down or just stop") - parser.add_option( - "--instance-profile-name", default=None, - help="IAM profile name to launch instances under") - - (opts, args) = parser.parse_args() - if len(args) != 2: - parser.print_help() - sys.exit(1) - (action, cluster_name) = args - - # Boto config check - # http://boto.cloudhackers.com/en/latest/boto_config_tut.html - home_dir = os.getenv('HOME') - if home_dir is None or not os.path.isfile(home_dir + '/.boto'): - if not os.path.isfile('/etc/boto.cfg'): - # If there is no boto config, check aws credentials - if not os.path.isfile(home_dir + '/.aws/credentials'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) - return (opts, action, cluster_name) - - -# Get the EC2 security group of the given name, creating it if it doesn't exist -def get_or_make_group(conn, name, vpc_id): - groups = conn.get_all_security_groups() - group = [g for g in groups if g.name == name] - if len(group) > 0: - return group[0] - else: - print("Creating security group " + name) - return conn.create_security_group(name, "Spark EC2 group", vpc_id) - - -def get_validate_spark_version(version, repo): - if "." in version: - version = version.replace("v", "") - if version not in VALID_SPARK_VERSIONS: - print("Don't know about Spark version: {v}".format(v=version), file=stderr) - sys.exit(1) - return version - else: - github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = Request(github_commit_url) - request.get_method = lambda: 'HEAD' - try: - response = urlopen(request) - except HTTPError as e: - print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), - file=stderr) - print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) - sys.exit(1) - return version - - -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-06-19 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. -EC2_INSTANCE_TYPES = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c4.large": "hvm", - "c4.xlarge": "hvm", - "c4.2xlarge": "hvm", - "c4.4xlarge": "hvm", - "c4.8xlarge": "hvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "d2.xlarge": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "g2.2xlarge": "hvm", - "g2.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.xlarge": "hvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", - "m4.large": "hvm", - "m4.xlarge": "hvm", - "m4.2xlarge": "hvm", - "m4.4xlarge": "hvm", - "m4.10xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "t1.micro": "pvm", - "t2.micro": "hvm", - "t2.small": "hvm", - "t2.medium": "hvm", - "t2.large": "hvm", -} - - -def get_tachyon_version(spark_version): - return SPARK_TACHYON_MAP.get(spark_version, "") - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. -def get_spark_ami(opts): - if opts.instance_type in EC2_INSTANCE_TYPES: - instance_type = EC2_INSTANCE_TYPES[opts.instance_type] - else: - instance_type = "pvm" - print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) - - # URL prefix from which to fetch AMI information - ami_prefix = "{r}/{b}/ami-list".format( - r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), - b=opts.spark_ec2_git_branch) - - ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) - reader = codecs.getreader("ascii") - try: - ami = reader(urlopen(ami_path)).read().strip() - except: - print("Could not resolve AMI at: " + ami_path, file=stderr) - sys.exit(1) - - print("Spark AMI: " + ami) - return ami - - -# Launch a cluster of the given name, by setting up its security groups, -# and then starting new instances in them. -# Returns a tuple of EC2 reservation objects for the master and slaves -# Fails if there already instances running in the cluster's groups. -def launch_cluster(conn, opts, cluster_name): - if opts.identity_file is None: - print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) - sys.exit(1) - - if opts.key_pair is None: - print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) - sys.exit(1) - - user_data_content = None - if opts.user_data: - with open(opts.user_data) as user_data_file: - user_data_content = user_data_file.read() - - print("Setting up security groups...") - master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) - authorized_address = opts.authorized_address - if master_group.rules == []: # Group was just now created - if opts.vpc_id is None: - master_group.authorize(src_group=master_group) - master_group.authorize(src_group=slave_group) - else: - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize('tcp', 22, 22, authorized_address) - master_group.authorize('tcp', 8080, 8081, authorized_address) - master_group.authorize('tcp', 18080, 18080, authorized_address) - master_group.authorize('tcp', 19999, 19999, authorized_address) - master_group.authorize('tcp', 50030, 50030, authorized_address) - master_group.authorize('tcp', 50070, 50070, authorized_address) - master_group.authorize('tcp', 60070, 60070, authorized_address) - master_group.authorize('tcp', 4040, 4045, authorized_address) - # Rstudio (GUI for R) needs port 8787 for web access - master_group.authorize('tcp', 8787, 8787, authorized_address) - # HDFS NFS gateway requires 111,2049,4242 for tcp & udp - master_group.authorize('tcp', 111, 111, authorized_address) - master_group.authorize('udp', 111, 111, authorized_address) - master_group.authorize('tcp', 2049, 2049, authorized_address) - master_group.authorize('udp', 2049, 2049, authorized_address) - master_group.authorize('tcp', 4242, 4242, authorized_address) - master_group.authorize('udp', 4242, 4242, authorized_address) - # RM in YARN mode uses 8088 - master_group.authorize('tcp', 8088, 8088, authorized_address) - if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, authorized_address) - if slave_group.rules == []: # Group was just now created - if opts.vpc_id is None: - slave_group.authorize(src_group=master_group) - slave_group.authorize(src_group=slave_group) - else: - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize('tcp', 22, 22, authorized_address) - slave_group.authorize('tcp', 8080, 8081, authorized_address) - slave_group.authorize('tcp', 50060, 50060, authorized_address) - slave_group.authorize('tcp', 50075, 50075, authorized_address) - slave_group.authorize('tcp', 60060, 60060, authorized_address) - slave_group.authorize('tcp', 60075, 60075, authorized_address) - - # Check if instances are already running in our groups - existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if existing_slaves or (existing_masters and not opts.use_existing_master): - print("ERROR: There are already instances running in group %s or %s" % - (master_group.name, slave_group.name), file=stderr) - sys.exit(1) - - # Figure out Spark AMI - if opts.ami is None: - opts.ami = get_spark_ami(opts) - - # we use group ids to work around https://github.com/boto/boto/issues/350 - additional_group_ids = [] - if opts.additional_security_group: - additional_group_ids = [sg.id - for sg in conn.get_all_security_groups() - if opts.additional_security_group in (sg.name, sg.id)] - print("Launching instances...") - - try: - image = conn.get_all_images(image_ids=[opts.ami])[0] - except: - print("Could not find AMI " + opts.ami, file=stderr) - sys.exit(1) - - # Create block device mapping so that we can add EBS volumes if asked to. - # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz - block_map = BlockDeviceMapping() - if opts.ebs_vol_size > 0: - for i in range(opts.ebs_vol_num): - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.volume_type = opts.ebs_vol_type - device.delete_on_termination = True - block_map["/dev/sd" + chr(ord('s') + i)] = device - - # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). - if opts.instance_type.startswith('m3.'): - for i in range(get_num_disks(opts.instance_type)): - dev = BlockDeviceType() - dev.ephemeral_name = 'ephemeral%d' % i - # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.ascii_letters[i + 1] - block_map[name] = dev - - # Launch slaves - if opts.spot_price is not None: - # Launch spot instances with the requested price - print("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - my_req_ids = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_reqs = conn.request_spot_instances( - price=opts.spot_price, - image_id=opts.ami, - launch_group="launch-group-%s" % cluster_name, - placement=zone, - count=num_slaves_this_zone, - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_profile_name=opts.instance_profile_name) - my_req_ids += [req.id for req in slave_reqs] - i += 1 - - print("Waiting for spot instances to be granted...") - try: - while True: - time.sleep(10) - reqs = conn.get_all_spot_instance_requests() - id_to_req = {} - for r in reqs: - id_to_req[r.id] = r - active_instance_ids = [] - for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - if len(active_instance_ids) == opts.slaves: - print("All %d slaves granted" % opts.slaves) - reservations = conn.get_all_reservations(active_instance_ids) - slave_nodes = [] - for r in reservations: - slave_nodes += r.instances - break - else: - print("%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves)) - except: - print("Canceling spot instance requests") - conn.cancel_spot_instance_requests(my_req_ids) - # Log a warning if any of these requests actually launched instances: - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - running = len(master_nodes) + len(slave_nodes) - if running: - print(("WARNING: %d instances are still running" % running), file=stderr) - sys.exit(0) - else: - # Launch non-spot instances - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - slave_nodes = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - if num_slaves_this_zone > 0: - slave_res = image.run( - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - slave_nodes += slave_res.instances - print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id)) - i += 1 - - # Launch or resume masters - if existing_masters: - print("Starting master...") - for inst in existing_masters: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - master_nodes = existing_masters - else: - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run( - key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - - master_nodes = master_res.instances - print("Launched master in %s, regid = %s" % (zone, master_res.id)) - - # This wait time corresponds to SPARK-4983 - print("Waiting for AWS to propagate instance metadata...") - time.sleep(15) - - # Give the instances descriptive names and set additional tags - additional_tags = {} - if opts.additional_tags.strip(): - additional_tags = dict( - map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') - ) - - for master in master_nodes: - master.add_tags( - dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) - ) - - for slave in slave_nodes: - slave.add_tags( - dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) - ) - - # Return all the instances - return (master_nodes, slave_nodes) - - -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - """ - Get the EC2 instances in an existing cluster if available. - Returns a tuple of lists of EC2 instance objects for the masters and slaves. - """ - print("Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region)) - - def get_instances(group_names): - """ - Get all non-terminated instances that belong to any of the provided security groups. - - EC2 reservation filters and instance states are documented here: - http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options - """ - reservations = conn.get_all_reservations( - filters={"instance.group-name": group_names}) - instances = itertools.chain.from_iterable(r.instances for r in reservations) - return [i for i in instances if i.state not in ["shutting-down", "terminated"]] - - master_instances = get_instances([cluster_name + "-master"]) - slave_instances = get_instances([cluster_name + "-slaves"]) - - if any((master_instances, slave_instances)): - print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's'))) - - if not master_instances and die_on_error: - print("ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region), file=sys.stderr) - sys.exit(1) - - return (master_instances, slave_instances) - - -# Deploy configuration files and run setup scripts on a newly launched -# or started EC2 cluster. -def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = get_dns_name(master_nodes[0], opts.private_ips) - if deploy_ssh_key: - print("Generating cluster's SSH key on master...") - key_setup = """ - [ -f ~/.ssh/id_rsa ] || - (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) - """ - ssh(master, opts, key_setup) - dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print("Transferring cluster's SSH key to slaves...") - for slave in slave_nodes: - slave_address = get_dns_name(slave, opts.private_ips) - print(slave_address) - ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) - - modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] - - if opts.hadoop_major_version == "1": - modules = list(filter(lambda x: x != "mapreduce", modules)) - - if opts.ganglia: - modules.append('ganglia') - - # Clear SPARK_WORKER_INSTANCES if running on YARN - if opts.hadoop_major_version == "yarn": - opts.worker_instances = "" - - # NOTE: We should clone the repository before running deploy_files to - # prevent ec2-variables.sh from being overwritten - print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) - ssh( - host=master, - opts=opts, - command="rm -rf spark-ec2" - + " && " - + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, - b=opts.spark_ec2_git_branch) - ) - - print("Deploying files to master...") - deploy_files( - conn=conn, - root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", - opts=opts, - master_nodes=master_nodes, - slave_nodes=slave_nodes, - modules=modules - ) - - if opts.deploy_root_dir is not None: - print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) - deploy_user_files( - root_dir=opts.deploy_root_dir, - opts=opts, - master_nodes=master_nodes - ) - - print("Running setup on master...") - setup_spark_cluster(master, opts) - print("Done!") - - -def setup_spark_cluster(master, opts): - ssh(master, opts, "chmod u+x spark-ec2/setup.sh") - ssh(master, opts, "spark-ec2/setup.sh") - print("Spark standalone cluster started at http://%s:8080" % master) - - if opts.ganglia: - print("Ganglia started at http://%s:5080/ganglia" % master) - - -def is_ssh_available(host, opts, print_ssh_output=True): - """ - Check if SSH is available on a host. - """ - s = subprocess.Popen( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order - ) - cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout - - if s.returncode != 0 and print_ssh_output: - # extra leading newline is for spacing in wait_for_cluster_state() - print(textwrap.dedent("""\n - Warning: SSH connection error. (This could be temporary.) - Host: {h} - SSH return code: {r} - SSH output: {o} - """).format( - h=host, - r=s.returncode, - o=cmd_output.strip() - )) - - return s.returncode == 0 - - -def is_cluster_ssh_available(cluster_instances, opts): - """ - Check if SSH is available on all the instances in a cluster. - """ - for i in cluster_instances: - dns_name = get_dns_name(i, opts.private_ips) - if not is_ssh_available(host=dns_name, opts=opts): - return False - else: - return True - - -def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): - """ - Wait for all the instances in the cluster to reach a designated state. - - cluster_instances: a list of boto.ec2.instance.Instance - cluster_state: a string representing the desired state of all the instances in the cluster - value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as - 'running', 'terminated', etc. - (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) - """ - sys.stdout.write( - "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) - ) - sys.stdout.flush() - - start_time = datetime.now() - num_attempts = 0 - - while True: - time.sleep(5 * num_attempts) # seconds - - for i in cluster_instances: - i.update() - - max_batch = 100 - statuses = [] - for j in xrange(0, len(cluster_instances), max_batch): - batch = [i.id for i in cluster_instances[j:j + max_batch]] - statuses.extend(conn.get_all_instance_status(instance_ids=batch)) - - if cluster_state == 'ssh-ready': - if all(i.state == 'running' for i in cluster_instances) and \ - all(s.system_status.status == 'ok' for s in statuses) and \ - all(s.instance_status.status == 'ok' for s in statuses) and \ - is_cluster_ssh_available(cluster_instances, opts): - break - else: - if all(i.state == cluster_state for i in cluster_instances): - break - - num_attempts += 1 - - sys.stdout.write(".") - sys.stdout.flush() - - sys.stdout.write("\n") - - end_time = datetime.now() - print("Cluster is now in '{s}' state. Waited {t} seconds.".format( - s=cluster_state, - t=(end_time - start_time).seconds - )) - - -# Get number of local disks available for a given EC2 instance type. -def get_num_disks(instance_type): - # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-06-19 - # For easy maintainability, please keep this manually-inputted dictionary sorted by key. - disks_by_instance = { - "c1.medium": 1, - "c1.xlarge": 4, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "c4.large": 0, - "c4.xlarge": 0, - "c4.2xlarge": 0, - "c4.4xlarge": 0, - "c4.8xlarge": 0, - "cc1.4xlarge": 2, - "cc2.8xlarge": 4, - "cg1.4xlarge": 2, - "cr1.8xlarge": 2, - "d2.xlarge": 3, - "d2.2xlarge": 6, - "d2.4xlarge": 12, - "d2.8xlarge": 24, - "g2.2xlarge": 1, - "g2.8xlarge": 2, - "hi1.4xlarge": 2, - "hs1.8xlarge": 24, - "i2.xlarge": 1, - "i2.2xlarge": 2, - "i2.4xlarge": 4, - "i2.8xlarge": 8, - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "m4.large": 0, - "m4.xlarge": 0, - "m4.2xlarge": 0, - "m4.4xlarge": 0, - "m4.10xlarge": 0, - "r3.large": 1, - "r3.xlarge": 1, - "r3.2xlarge": 1, - "r3.4xlarge": 1, - "r3.8xlarge": 2, - "t1.micro": 0, - "t2.micro": 0, - "t2.small": 0, - "t2.medium": 0, - "t2.large": 0, - } - if instance_type in disks_by_instance: - return disks_by_instance[instance_type] - else: - print("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type, file=stderr) - return 1 - - -# Deploy the configuration file templates in a given local directory to -# a cluster, filling in any template parameters with information about the -# cluster (e.g. lists of masters and slaves). Files are only deployed to -# the first master instance in the cluster, and we expect the setup -# script to be run on that instance to copy them to other nodes. -# -# root_dir should be an absolute path to the directory with the files we want to deploy. -def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - - num_disks = get_num_disks(opts.instance_type) - hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" - mapred_local_dirs = "/mnt/hadoop/mrlocal" - spark_local_dirs = "/mnt/spark" - if num_disks > 1: - for i in range(2, num_disks + 1): - hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i - mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i - spark_local_dirs += ",/mnt%d/spark" % i - - cluster_url = "%s:7077" % active_master - - if "." in opts.spark_version: - # Pre-built Spark deploy - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - tachyon_v = get_tachyon_version(spark_v) - else: - # Spark-only custom deploy - spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - tachyon_v = "" - print("Deploying Spark via git hash; Tachyon won't be set up") - modules = filter(lambda x: x != "tachyon", modules) - - master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] - slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] - worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" - template_vars = { - "master_list": '\n'.join(master_addresses), - "active_master": active_master, - "slave_list": '\n'.join(slave_addresses), - "cluster_url": cluster_url, - "hdfs_data_dirs": hdfs_data_dirs, - "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs, - "swap": str(opts.swap), - "modules": '\n'.join(modules), - "spark_version": spark_v, - "tachyon_version": tachyon_v, - "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": worker_instances_str, - "spark_master_opts": opts.master_opts - } - - if opts.copy_aws_credentials: - template_vars["aws_access_key_id"] = conn.aws_access_key_id - template_vars["aws_secret_access_key"] = conn.aws_secret_access_key - else: - template_vars["aws_access_key_id"] = "" - template_vars["aws_secret_access_key"] = "" - - # Create a temp directory in which we will place all the files to be - # deployed after we substitue template parameters in them - tmp_dir = tempfile.mkdtemp() - for path, dirs, files in os.walk(root_dir): - if path.find(".svn") == -1: - dest_dir = os.path.join('/', path[len(root_dir):]) - local_dir = tmp_dir + dest_dir - if not os.path.exists(local_dir): - os.makedirs(local_dir) - for filename in files: - if filename[0] not in '#.~' and filename[-1] != '~': - dest_file = os.path.join(dest_dir, filename) - local_file = tmp_dir + dest_file - with open(os.path.join(path, filename)) as src: - with open(local_file, "w") as dest: - text = src.read() - for key in template_vars: - text = text.replace("{{" + key + "}}", template_vars[key]) - dest.write(text) - dest.close() - # rsync the whole directory over to the master machine - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s/" % tmp_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - # Remove the temp directory we created above - shutil.rmtree(tmp_dir) - - -# Deploy a given local directory to a cluster, WITHOUT parameter substitution. -# Note that unlike deploy_files, this works for binary files. -# Also, it is up to the user to add (or not) the trailing slash in root_dir. -# Files are only deployed to the first master instance in the cluster. -# -# root_dir should be an absolute path. -def deploy_user_files(root_dir, opts, master_nodes): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s" % root_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - - -def stringify_command(parts): - if isinstance(parts, str): - return parts - else: - return ' '.join(map(pipes.quote, parts)) - - -def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no'] - parts += ['-o', 'UserKnownHostsFile=/dev/null'] - if opts.identity_file is not None: - parts += ['-i', opts.identity_file] - return parts - - -def ssh_command(opts): - return ['ssh'] + ssh_args(opts) - - -# Run a command on a host through ssh, retrying up to five times -# and then throwing an exception if ssh continues to fail. -def ssh(host, opts, command): - tries = 0 - while True: - try: - return subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), - stringify_command(command)]) - except subprocess.CalledProcessError as e: - if tries > 5: - # If this was an ssh failure, provide the user with hints. - if e.returncode == 255: - raise UsageError( - "Failed to SSH to remote host {0}.\n" - "Please check that you have provided the correct --identity-file and " - "--key-pair parameters and try again.".format(host)) - else: - raise e - print("Error executing remote command, retrying after 30 seconds: {0}".format(e), - file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) -def _check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - -def ssh_read(host, opts, command): - return _check_output( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) - - -def ssh_write(host, opts, command, arguments): - tries = 0 - while True: - proc = subprocess.Popen( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], - stdin=subprocess.PIPE) - proc.stdin.write(arguments) - proc.stdin.close() - status = proc.wait() - if status == 0: - break - elif tries > 5: - raise RuntimeError("ssh_write failed with error %s" % proc.returncode) - else: - print("Error {0} while executing remote command, retrying after 30 seconds". - format(status), file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Gets a list of zones to launch instances in -def get_zones(conn, opts): - if opts.zone == 'all': - zones = [z.name for z in conn.get_all_zones()] - else: - zones = [opts.zone] - return zones - - -# Gets the number of items in a partition -def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total // num_partitions - if (total % num_partitions) - current_partitions > 0: - num_slaves_this_zone += 1 - return num_slaves_this_zone - - -# Gets the IP address, taking into account the --private-ips flag -def get_ip_address(instance, private_ips=False): - ip = instance.ip_address if not private_ips else \ - instance.private_ip_address - return ip - - -# Gets the DNS name, taking into account the --private-ips flag -def get_dns_name(instance, private_ips=False): - dns = instance.public_dns_name if not private_ips else \ - instance.private_ip_address - if not dns: - raise UsageError("Failed to determine hostname of {0}.\n" - "Please check that you provided --private-ips if " - "necessary".format(instance)) - return dns - - -def real_main(): - (opts, action, cluster_name) = parse_args() - - # Input parameter validation - get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - - if opts.wait is not None: - # NOTE: DeprecationWarnings are silent in 2.7+ by default. - # To show them, run Python with the -Wdefault switch. - # See: https://docs.python.org/3.5/whatsnew/2.7.html - warnings.warn( - "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to start up.", - DeprecationWarning - ) - - if opts.identity_file is not None: - if not os.path.exists(opts.identity_file): - print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - file_mode = os.stat(opts.identity_file).st_mode - if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print("ERROR: The identity file must be accessible only by you.", file=stderr) - print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - if opts.instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type), file=stderr) - - if opts.master_instance_type != "": - if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type), file=stderr) - # Since we try instance types even if we can't resolve them, we check if they resolve first - # and, if they do, see if they resolve to the same virtualization type. - if opts.instance_type in EC2_INSTANCE_TYPES and \ - opts.master_instance_type in EC2_INSTANCE_TYPES: - if EC2_INSTANCE_TYPES[opts.instance_type] != \ - EC2_INSTANCE_TYPES[opts.master_instance_type]: - print("Error: spark-ec2 currently does not support having a master and slaves " - "with different AMI virtualization types.", file=stderr) - print("master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) - print("slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) - sys.exit(1) - - if opts.ebs_vol_num > 8: - print("ebs-vol-num cannot be greater than 8", file=stderr) - sys.exit(1) - - # Prevent breaking ami_prefix (/, .git and startswith checks) - # Prevent forks with non spark-ec2 names for now. - if opts.spark_ec2_git_repo.endswith("/") or \ - opts.spark_ec2_git_repo.endswith(".git") or \ - not opts.spark_ec2_git_repo.startswith("https://github.com") or \ - not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " - "Furthermore, we currently only support forks named spark-ec2.", file=stderr) - sys.exit(1) - - if not (opts.deploy_root_dir is None or - (os.path.isabs(opts.deploy_root_dir) and - os.path.isdir(opts.deploy_root_dir) and - os.path.exists(opts.deploy_root_dir))): - print("--deploy-root-dir must be an absolute path to a directory that exists " - "on the local file system", file=stderr) - sys.exit(1) - - try: - if opts.profile is None: - conn = ec2.connect_to_region(opts.region) - else: - conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) - except Exception as e: - print((e), file=stderr) - sys.exit(1) - - # Select an AZ at random if it was not specified. - if opts.zone == "": - opts.zone = random.choice(conn.get_all_zones()).name - - if action == "launch": - if opts.slaves <= 0: - print("ERROR: You have to start at least 1 slave", file=sys.stderr) - sys.exit(1) - if opts.resume: - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - else: - (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, True) - - elif action == "destroy": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - - if any(master_nodes + slave_nodes): - print("The following instances will be terminated:") - for inst in master_nodes + slave_nodes: - print("> %s" % get_dns_name(inst, opts.private_ips)) - print("ALL DATA ON ALL NODES WILL BE LOST!!") - - msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) - response = raw_input(msg) - if response == "y": - print("Terminating master...") - for inst in master_nodes: - inst.terminate() - print("Terminating slaves...") - for inst in slave_nodes: - inst.terminate() - - # Delete security groups as well - if opts.delete_groups: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated' - ) - print("Deleting security groups (this will take some time)...") - attempt = 1 - while attempt <= 3: - print("Attempt %d" % attempt) - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - success = True - # Delete individual rules in all groups before deleting groups to - # remove dependencies between them - for group in groups: - print("Deleting rules in security group " + group.name) - for rule in group.rules: - for grant in rule.grants: - success &= group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - - # Sleep for AWS eventual-consistency to catch up, and for instances - # to terminate - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - try: - # It is needed to use group_id to make it work with VPC - conn.delete_security_group(group_id=group.id) - print("Deleted security group %s" % group.name) - except boto.exception.EC2ResponseError: - success = False - print("Failed to delete security group %s" % group.name) - - # Unfortunately, group.revoke() returns True even if a rule was not - # deleted, so this needs to be rerun if something fails - if success: - break - - attempt += 1 - - if not success: - print("Failed to delete all security groups after 3 tries.") - print("Try re-running in a few minutes.") - - elif action == "login": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - master = get_dns_name(master_nodes[0], opts.private_ips) - print("Logging into master " + master + "...") - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) - - elif action == "reboot-slaves": - response = raw_input( - "Are you sure you want to reboot the cluster " + - cluster_name + " slaves?\n" + - "Reboot cluster slaves " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Rebooting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - print("Rebooting " + inst.id) - inst.reboot() - - elif action == "get-master": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - print(get_dns_name(master_nodes[0], opts.private_ips)) - - elif action == "stop": - response = raw_input( - "Are you sure you want to stop the cluster " + - cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + - "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + - "AMAZON EBS IF IT IS EBS-BACKED!!\n" + - "All data on spot-instance slaves will be lost.\n" + - "Stop cluster " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Stopping master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.stop() - print("Stopping slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - if inst.spot_instance_request_id: - inst.terminate() - else: - inst.stop() - - elif action == "start": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print("Starting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - print("Starting master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - - # Determine types of running instances - existing_master_type = master_nodes[0].instance_type - existing_slave_type = slave_nodes[0].instance_type - # Setting opts.master_instance_type to the empty string indicates we - # have the same instance type for the master and the slaves - if existing_master_type == existing_slave_type: - existing_master_type = "" - opts.master_instance_type = existing_master_type - opts.instance_type = existing_slave_type - - setup_cluster(conn, master_nodes, slave_nodes, opts, False) - - else: - print("Invalid action: %s" % action, file=stderr) - sys.exit(1) - - -def main(): - try: - real_main() - except UsageError as e: - print("\nError:\n", e, file=stderr) - sys.exit(1) - - -if __name__ == "__main__": - logging.basicConfig() - main() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index f6b0ecb02c100..b6c2916254056 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -30,7 +30,7 @@ lines = sc.textFile(sys.argv[1], 1) sortedCount = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ - .sortByKey(lambda x: x) + .sortByKey() # This is just a demo on how to bring all the sorted data back to a single node. # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 3834ea807acbf..c4336639d7c0b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegression object IsotonicRegressionExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("IsotonicRegressionExample") val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index 8bae1b9d1832d..0187ad603a654 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint object NaiveBayesExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("NaiveBayesExample") val sc = new SparkContext(conf) // $example on$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala index ace16ff1ea225..add634c957b40 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SQLContext object RegressionMetricsExample { - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("RegressionMetricsExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index c4e18d92eefa9..d7885d7cc1ae1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -385,7 +385,7 @@ object KafkaCluster { val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp => val hpa = hp.split(":") if (hpa.size == 1) { - throw new SparkException(s"Broker not the in correct format of : [$brokers]") + throw new SparkException(s"Broker not in the correct format of : [$brokers]") } (hpa(0), hpa(1).toInt) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 603be22818206..4eb155645867b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -156,7 +156,7 @@ class KafkaRDD[ var requestOffset = part.fromOffset var iter: Iterator[MessageAndOffset] = null - // The idea is to use the provided preferred host, except on task retry atttempts, + // The idea is to use the provided preferred host, except on task retry attempts, // to minimize number of kafka metadata requests private def connectLeader: SimpleConsumer = { if (context.attemptNumber > 0) { diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index b3ba72a0087ad..d3a2bf5825b08 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -51,7 +51,7 @@ org.eclipse.paho org.eclipse.paho.client.mqttv3 - 1.0.1 + 1.0.2 org.scalacheck diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index de749626ec09c..6a73bc0e30d05 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.util.Random -import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.auth.{BasicAWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 3321c7527edb4..5223c81a8e0e0 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -24,10 +24,10 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo -import org.apache.spark.streaming.{Duration, StreamingContext, Time} private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index abb9b6cd32f1c..48ee2a959786b 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2de6195716e5c..15ac588b82587 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -24,9 +24,9 @@ import com.amazonaws.services.kinesis.model.Record import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, StreamingContext} import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Duration, StreamingContext} object KinesisUtils { /** diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index d85b4cda8ce98..e6f504c4e54dd 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.kinesis import org.scalatest.BeforeAndAfterAll -import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) extends KinesisFunSuite with BeforeAndAfterAll { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala index 645e64a0bc3a0..e1499a8220991 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.kinesis -import java.util.concurrent.{TimeoutException, ExecutorService} +import java.util.concurrent.{ExecutorService, TimeoutException} import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ @@ -28,7 +28,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ import org.scalatest.mock.MockitoSugar diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index e5c70db554a27..fd15b6ccdc889 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -27,8 +27,8 @@ import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ -import org.scalatest.mock.MockitoSugar import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} import org.apache.spark.util.Utils diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 78263f9dca65c..ee6a5f0390d04 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -25,10 +25,11 @@ import scala.util.Random import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.Matchers._ import org.scalatest.concurrent.Eventually -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} @@ -38,7 +39,6 @@ import org.apache.spark.streaming.kinesis.KinesisTestUtils._ import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext} abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index fc36e12dd2aed..d048fb5d561f3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkException -import org.apache.spark.SparkContext._ import org.apache.spark.graphx.lib._ import org.apache.spark.rdd.RDD @@ -379,7 +378,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] */ def personalizedPageRank(src: VertexId, tol: Double, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) } @@ -392,7 +391,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] */ def staticPersonalizedPageRank(src: VertexId, numIter: Int, - resetProb: Double = 0.15) : Graph[Double, Double] = { + resetProb: Double = 0.15): Graph[Double, Double] = { PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index f79f9c7ec448f..b4bec7cba5207 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -41,8 +41,8 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { - new ReplicatedVertexView(edges_, hasSrcId, hasDstId) + _edges: EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + new ReplicatedVertexView(_edges, hasSrcId, hasDstId) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index 3f203c4eca485..96d807f9f9ceb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -102,8 +102,8 @@ class ShippableVertexPartition[VD: ClassTag]( extends VertexPartitionBase[VD] { /** Return a new ShippableVertexPartition with the specified routing table. */ - def withRoutingTable(routingTable_ : RoutingTablePartition): ShippableVertexPartition[VD] = { - new ShippableVertexPartition(index, values, mask, routingTable_) + def withRoutingTable(_routingTable: RoutingTablePartition): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(index, values, mask, _routingTable) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index f508b483a2f1b..7c680dcb99cd2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.collection.BitSet * example, [[VertexPartition.VertexPartitionOpsConstructor]]). */ private[graphx] abstract class VertexPartitionBaseOps - [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor] + [VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor] (self: Self[VD]) extends Serializable with Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 35b26c998e1d9..46faad2e68c50 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -138,7 +138,7 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId , id: VertexId) => resetProb * delta(src, id) + (src: VertexId, id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } diff --git a/make-distribution.sh b/make-distribution.sh index a38fd8df17206..327659298e4d8 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -212,7 +212,6 @@ cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" -cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Copy SparkR if it exists if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 08a51109d6c62..c41a611f1cc60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -113,13 +113,13 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } - val transformedDataset = model.transform(df).select(columns : _*) + val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f9952434d2982..6cc9d025445c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -238,7 +238,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def transform(dataset: DataFrame): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) - dataset.select(columnsToKeep.map(dataset.col) : _*) + dataset.select(columnsToKeep.map(dataset.col): _*) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0b215659b3672..716bc63e00995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -102,7 +102,7 @@ class VectorAssembler(override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) + dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 6e87302c7779b..d3376a7dff938 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -474,7 +474,7 @@ private[ml] object RandomForest extends Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator @@ -825,7 +825,7 @@ private[ml] object RandomForest extends Logging { protected[tree] def findSplits( input: RDD[LabeledPoint], metadata: DecisionTreeMetadata, - seed : Long): Array[Array[Split]] = { + seed: Long): Array[Array[Split]] = { logDebug("isMulticlass = " + metadata.isMulticlass) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 7443097492d82..7a651a37ac77e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} -import org.apache.spark.sql.types.{DoubleType, DataType, StructType} +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** * Parameters for Decision Tree-based algorithms. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 5c9bc62cb09bb..16bc45bcb627f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -177,7 +177,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } @Since("1.4.0") - override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + override def load(sc: SparkContext, path: String): GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats val k = (metadata \ "k").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 5273ed4d76650..1250bc1a07cb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.fpm -import java.lang.{Iterable => JavaIterable} import java.{util => ju} +import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,16 +29,15 @@ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} +import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -134,7 +133,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { loadImpl(freqItemsets, sample) } - def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => val items = x.getAs[Seq[Item]](0).toArray val freq = x.getLong(1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d7a74db0b1fd8..b08da4fb55034 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -279,7 +279,7 @@ class DenseMatrix @Since("1.3.0") ( } override def hashCode: Int = { - com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray) + com.google.common.base.Objects.hashCode(numRows: Integer, numCols: Integer, toArray) } private[mllib] def toBreeze: BM[Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 7abb1bf7ce967..a8c32f72bfdeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, - description : String, - normalizationMethod : RegressionNormalizationMethodType, + model: GeneralizedLinearModel, + description: String, + normalizationMethod: RegressionNormalizationMethodType, threshold: Double) extends PMMLModelExport { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala index b5b824bb9c9b6..255c6140e5410 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -26,14 +26,14 @@ import org.apache.spark.mllib.clustering.KMeansModel /** * PMML Model Export for KMeansModel class */ -private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ +private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModelExport{ populateKMeansPMML(model) /** * Export the input KMeansModel model to PMML format. */ - private def populateKMeansPMML(model : KMeansModel): Unit = { + private def populateKMeansPMML(model: KMeansModel): Unit = { pmml.getHeader.setDescription("k-means clustering") if (model.clusterCenters.length > 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index af1f7e74c004d..07ba0d8ccb2a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -25,10 +25,10 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl._ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ @@ -600,7 +600,7 @@ object DecisionTree extends Serializable with Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 729a211574822..1b71256c585bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -22,8 +22,8 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index a684cdd18c2fc..570a76f960796 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -26,9 +26,9 @@ import org.apache.spark.Logging import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 66f0908c1250f..b373c2de3ea96 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node @Since("1.2.0") ( * @return predicted value */ @Since("1.1.0") - def predict(features: Vector) : Double = { + def predict(features: Vector): Double = { if (isLeaf) { predict.predict } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 094528e2ece06..240781bcd335b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -175,7 +175,7 @@ object LinearDataGenerator { nfeatures: Int, eps: Double, nparts: Int = 2, - intercept: Double = 0.0) : RDD[LabeledPoint] = { + intercept: Double = 0.0): RDD[LabeledPoint] = { val random = new Random(42) // Random values distributed uniformly in [-0.5, 0.5] val w = Array.fill(nfeatures)(random.nextDouble() - 0.5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index ee3c85d09a463..1a47344b68937 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -45,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) + val weightsMat = new DoubleMatrix(1, weights.length, weights: _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 1142102bb040e..50441816ece3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { - override def maxWaitTimeMillis : Int = 30000 + override def maxWaitTimeMillis: Int = 30000 test("accuracy for null hypothesis using welch t-test") { // set parameters diff --git a/network/common/pom.xml b/network/common/pom.xml index 92ca0046d4f53..eda2b7307088f 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -55,6 +55,7 @@ com.google.guava guava + compile diff --git a/pom.xml b/pom.xml index 9c975a45f8d23..fc5cf970e0601 100644 --- a/pom.xml +++ b/pom.xml @@ -152,9 +152,9 @@ 1.7.7 hadoop2 0.7.1 - 1.4.0 + 1.6.1 - 0.10.1 + 0.10.2 4.3.2 @@ -167,7 +167,7 @@ ${scala.version} org.scala-lang 1.9.13 - 2.4.4 + 2.5.3 1.1.2 1.1.2 1.2.0-incubating @@ -226,93 +226,6 @@ false - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - - jboss-repo - JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases - - true - - - false - - - - mqtt-repo - MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases - - true - - - false - - - - cloudera-repo - Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos - - true - - - false - - - - spark-hive-staging - Staging Repo for Hive 1.2.1 (Spark Version) - https://oss.sonatype.org/content/repositories/orgspark-project-1113 - - true - - - - mapr-repo - MapR Repository - http://repository.mapr.com/maven/ - - true - - - false - - - - - spring-releases - Spring Release Repository - https://repo.spring.io/libs-release - - false - - - false - - - - - twttr-repo - Twttr Repository - http://maven.twttr.com - - true - - - false - - @@ -1133,6 +1046,12 @@ zookeeper ${zookeeper.version} ${hadoop.deps.scope} + + + org.jboss.netty + netty + + org.codehaus.jackson @@ -1858,7 +1777,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4 + 1.4.1 enforce-versions @@ -1873,6 +1792,19 @@ ${java.version} + + + + org.jboss.netty + + true + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 9ba9f8286f10c..41856443af49b 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,11 +91,16 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" + // The resolvers setting for MQTT Repository is needed for mqttv3(1.0.1) + // because spark-streaming-mqtt(1.6.0) depends on it. + // Remove the setting on updating previousSparkVersion. val previousSparkVersion = "1.6.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), - binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) + binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value), + sbt.Keys.resolvers += + "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases") } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5d4f19ab14a29..4c34c888cfd5e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -141,7 +141,12 @@ object SparkBuild extends PomBuild { publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", - resolvers += Resolver.mavenLocal, + // Override SBT's default resolvers: + resolvers := Seq( + DefaultMavenRepository, + Resolver.mavenLocal + ), + externalResolvers := resolvers.value, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) diff --git a/project/plugins.sbt b/project/plugins.sbt index 15ba3a36d51ca..822a7c4a82d5e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,9 +1,3 @@ -resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) - -resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" - -resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" - addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 9714c46fe99a0..2439a1f715aba 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -187,6 +187,16 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + @@ -207,17 +217,6 @@ This file is divided into 3 sections: - - - - java,scala,3rdParty,spark - javax?\..* - scala\..* - (?!org\.apache\.spark\.).* - org\.apache\.spark\..* - - - diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index cad770122d150..aabb5d49582c8 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -223,7 +223,12 @@ precedenceUnaryPrefixExpression ; precedenceUnarySuffixExpression - : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + : + ( + (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN + | + precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + ) -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) -> precedenceUnaryPrefixExpression ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index b04bb677774c5..2c13d3056f468 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -1,9 +1,9 @@ /** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with + (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 @@ -582,7 +582,7 @@ import java.util.HashMap; return header; } - + @Override public String getErrorMessage(RecognitionException e, String[] tokenNames) { String msg = null; @@ -619,7 +619,7 @@ import java.util.HashMap; } return msg; } - + public void pushMsg(String msg, RecognizerSharedState state) { // ANTLR generated code does not wrap the @init code wit this backtracking check, // even if the matching @after has it. If we have parser rules with that are doing @@ -639,7 +639,7 @@ import java.util.HashMap; // counter to generate unique union aliases private int aliasCounter; private String generateUnionAlias() { - return "_u" + (++aliasCounter); + return "u_" + (++aliasCounter); } private char [] excludedCharForColumnName = {'.', ':'}; private boolean containExcludedCharForCreateTableColumnName(String input) { @@ -1235,7 +1235,7 @@ alterTblPartitionStatementSuffixSkewedLocation : KW_SET KW_SKEWED KW_LOCATION skewedLocations -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) ; - + skewedLocations @init { pushMsg("skewed locations", state); } @after { popMsg(state); } @@ -1264,7 +1264,7 @@ alterStatementSuffixLocation -> ^(TOK_ALTERTABLE_LOCATION $newLoc) ; - + alterStatementSuffixSkewedby @init {pushMsg("alter skewed by statement", state);} @after{popMsg(state);} @@ -1336,10 +1336,10 @@ tabTypeExpr (identifier (DOT^ ( (KW_ELEM_TYPE) => KW_ELEM_TYPE - | + | (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier ))* )? @@ -1376,7 +1376,7 @@ descStatement analyzeStatement @init { pushMsg("analyze statement", state); } @after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) ; @@ -1389,7 +1389,7 @@ showStatement | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? -> ^(TOK_SHOWCOLUMNS tableName $db_name?) | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) - | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) | KW_SHOW KW_CREATE ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) | @@ -1398,7 +1398,7 @@ showStatement | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS + | KW_SHOW KW_LOCKS ( (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) | @@ -1511,7 +1511,7 @@ showCurrentRole setRole @init {pushMsg("set role", state);} @after {popMsg(state);} - : KW_SET KW_ROLE + : KW_SET KW_ROLE ( (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) | @@ -1966,7 +1966,7 @@ columnNameOrderList skewedValueElement @init { pushMsg("skewed value element", state); } @after { popMsg(state); } - : + : skewedColumnValues | skewedColumnValuePairList ; @@ -1980,8 +1980,8 @@ skewedColumnValuePairList skewedColumnValuePair @init { pushMsg("column value pair", state); } @after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN + : + LPAREN colValues=skewedColumnValues RPAREN -> ^(TOK_TABCOLVALUES $colValues) ; @@ -2001,11 +2001,11 @@ skewedColumnValue skewedValueLocationElement @init { pushMsg("skewed value location element", state); } @after { popMsg(state); } - : + : skewedColumnValue | skewedColumnValuePair ; - + columnNameOrder @init { pushMsg("column name order", state); } @after { popMsg(state); } @@ -2118,7 +2118,7 @@ unionType @after { popMsg(state); } : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) ; - + setOperator @init { pushMsg("set operator", state); } @after { popMsg(state); } @@ -2172,7 +2172,7 @@ fromStatement[boolean topLevel] {adaptor.create(Identifier, generateUnionAlias())} ) ) - ^(TOK_INSERT + ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) ) @@ -2414,8 +2414,8 @@ setColumnsClause KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) ; -/* - UPDATE +/* + UPDATE
    SET col1 = val1, col2 = val2... WHERE ... */ updateStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 1eda4a9a97644..2e3cc0bfde7c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -22,10 +22,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e362b55d80cd1..8a33af8207350 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,8 +86,7 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), + PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Computes the current date and time to make sure we return the same result in a single query. - */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e8b2fcf819bf6..a8f89ce6de457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. - alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias + .map(a => Subquery(a, tableWithQualifiers)) + .getOrElse(tableWithQualifiers) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d82d3edae4e38..6f199cfc5d8cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { $evPrim = $result.copy(); """ } + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a9c12127d367..d6219514b752b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] { protected def toCommentSafeString: String = this.toString .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + + /** + * Returns SQL representation of this expression. For expressions that don't have a SQL + * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`. + */ + @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") + def sql: String = throw new UnsupportedOperationException( + s"Cannot map expression $this to its SQL representation" + ) } @@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression { """ } } + + override def sql: String = s"($prettyName(${child.sql}))" } @@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression { """ } } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { TypeCheckResult.TypeCheckSuccess } } + + override def sql: String = s"(${left.sql} $symbol ${right.sql})" } @@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression { """ } } + + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index f33833c3918df..827dce8af100e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic { "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } + override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index d0b78e15d99d1..94f8801dec369 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" + + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3add722da7816..1cb1b9da3049b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -24,9 +24,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple. This class extends expression primarily so that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b47f32d1768b9..ddd99c51ab0c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -93,11 +94,13 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + + override def sql: String = aggregateFunction.sql(isDistinct) } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. @@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0fe..7bd851c059d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: String = s"(-${child.sql})" } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: String = s"(+${child.sql})" } /** @@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { @@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c73239f67ff2..5bd97cc7467ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] } }) } + + override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f79c8676fb58c..19da849d2bec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ @@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override def toString: String = s"if ($predicate) $trueValue else $falseValue" + + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } trait CaseWhenLike extends Expression { @@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression { override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } } @@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } // scalastyle:off @@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val keySQL = key.sql + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE $keySQL " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 3d65946a1bc65..17f1df06f2fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + override def prettyName: String = "current_date" } /** @@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** @@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression) s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** @@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression) s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } + + override def prettyName: String = "to_unix_timestamp" } /** @@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi def this() = { this(CurrentTimestamp()) } + + override def prettyName: String = "unix_timestamp" } abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { @@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ } } + + override def prettyName: String = "unix_time" } /** @@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression) override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression) s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** @@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression) s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** @@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index c54bcdd774021..5f8b544edb511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" + override def sql: String = child.sql } /** @@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 672cc9c45e0af..17351ef0685a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType) } } } + + override def sql: String = (value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => + "NULL" + + case _ if value == null => + s"CAST(NULL AS ${dataType.sql})" + + case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes. + "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + + case (v: Byte, ByteType) => + s"CAST($v AS ${ByteType.simpleString.toUpperCase})" + + case (v: Short, ShortType) => + s"CAST($v AS ${ShortType.simpleString.toUpperCase})" + + case (v: Long, LongType) => + s"CAST($v AS ${LongType.simpleString.toUpperCase})" + + case (v: Float, FloatType) => + s"CAST($v AS ${FloatType.simpleString.toUpperCase})" + + case (v: Decimal, DecimalType.Fixed(precision, scale)) => + s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" + + case (v: Int, DateType) => + s"DATE '${DateTimeUtils.toJavaDate(v)}'" + + case (v: Long, TimestampType) => + s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" + + case _ => value.toString + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 002f5929cc26b..66d8631a846ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } + + override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index fd95b124b2455..cc406a39f0408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); """ } + + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index eefd9c7482553..eee708cb02f9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)( explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"${child.sql} AS $qualifiersString`$name`" + } } /** @@ -271,6 +277,12 @@ case class AttributeReference( // Since the expression id is not in the first constructor it is missing from the default // tree string. override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"$qualifiersString`$name`" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index df4747d4e6f7a..89aec2b20fd0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.value = eval.isNull eval.code } + + override def sql: String = s"(${child.sql} IS NULL)" } @@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: String = s"(${child.sql} IS NOT NULL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba4..bca12a8d21023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -101,6 +101,8 @@ case class Not(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: String = s"(NOT ${child.sql})" } @@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN ($listSQL))" + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } + + override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } + + override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe876..8de47e9ddc28d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index adef6050c3565..db266639b8560 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 50c8b9d59847e..931f752b4dc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression]) """ } } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes { @@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" } /** @@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8b..f8121a733a8d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: + Batch("Compute Current Time", Once, + ComputeCurrentTime) :: Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: @@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] { ) Project(cleanedProjection, child) } + + // TODO Eliminate duplicate code + // This clause is identical to the one above except that the inner operator is an `Aggregate` + // rather than a `Project`. + case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliasMap = AttributeMap(projectList2.collect { + case a: Alias => (a.toAttribute, a) + }) + + // We only collapse these two Projects if their overlapped expressions are all + // deterministic. + val hasNondeterministic = projectList1.exists(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a).child + }.exists(!_.deterministic)) + + if (hasNondeterministic) { + p + } else { + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + agg.copy(aggregateExpressions = cleanedProjection) + } } } @@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..a5f6764aef7ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -37,14 +37,26 @@ object JoinType { } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} -case object Inner extends JoinType +case object Inner extends JoinType { + override def sql: String = "INNER" +} -case object LeftOuter extends JoinType +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} -case object RightOuter extends JoinType +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} -case object FullOuter extends JoinType +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object LeftSemi extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 79759b5a37b34..64957db6b4013 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 62ea731ab5f38..9ebacb4680dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -37,7 +37,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f9..7a0d0de6328a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,20 @@ package object util { ret } + /** + * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. + */ + def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { + case xs if xs.isEmpty => + Option(Seq.empty[T]) + + case xs => + for { + head <- xs.head + tail <- sequenceOption(xs.tail) + } yield head +: tail + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 6533622492d41..520e344361625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 136a97e066df7..92cf8d4c46bda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0a..5474954af70e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 34382bf124eb0..3bd733fa2d26c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,8 +25,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.{LegacyTypeStringParser, DataTypeParser} - +import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser} /** * :: DeveloperApi :: @@ -279,6 +278,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + s"STRUCT<${fieldTypes.mkString(", ")}>" + } + private[sql] override def simpleString(maxNumberFields: Int): String = { val builder = new StringBuilder val fieldTypes = fields.take(maxNumberFields).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd9..d7a2c23be8a9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index 30978d9b49e2b..d7204c3488313 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -20,17 +20,33 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.plans.PlanTest class CatalystQlSuite extends PlanTest { + val parser = new CatalystQl() test("parse union/except/intersect") { - val paresr = new CatalystQl() - paresr.createPlan("select * from t1 union all select * from t2") - paresr.createPlan("select * from t1 union distinct select * from t2") - paresr.createPlan("select * from t1 union select * from t2") - paresr.createPlan("select * from t1 except select * from t2") - paresr.createPlan("select * from t1 intersect select * from t2") - paresr.createPlan("(select * from t1) union all (select * from t2)") - paresr.createPlan("(select * from t1) union distinct (select * from t2)") - paresr.createPlan("(select * from t1) union (select * from t2)") - paresr.createPlan("select * from ((select * from t1) union (select * from t2)) t") + parser.createPlan("select * from t1 union all select * from t2") + parser.createPlan("select * from t1 union distinct select * from t2") + parser.createPlan("select * from t1 union select * from t2") + parser.createPlan("select * from t1 except select * from t2") + parser.createPlan("select * from t1 intersect select * from t2") + parser.createPlan("(select * from t1) union all (select * from t2)") + parser.createPlan("(select * from t1) union distinct (select * from t2)") + parser.createPlan("(select * from t1) union (select * from t2)") + parser.createPlan("select * from ((select * from t1) union (select * from t2)) t") + } + + test("window function: better support of parentheses") { + parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + + "order by 2) from windowData") + + parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + + "from windowData") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index fa823e3021835..cf84855885a37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } - test("analyzer should replace current_timestamp with literals") { - val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), - LocalRelation()) - - val min = System.currentTimeMillis() * 1000 - val plan = in.analyze.asInstanceOf[Project] - val max = (System.currentTimeMillis() + 1) * 1000 - - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - - test("analyzer should replace current_date with literals") { - val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) - - val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val plan = in.analyze.asInstanceOf[Project] - val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) - - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - test("SPARK-12102: Ignore nullablity when comparing two sides of case") { val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 0000000000000..10ed4e46ddd1c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b998636909a7d..f9f3bd55aa578 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a) - .select('a).analyze + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 21a6fba9078df..2355de3d05865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length - var partsScanned = 0L + var partsScanned = 0 while (buf.size < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. @@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = n - buf.size - val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += p.size diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index a322688a259e2..f3e89ef4a71f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} -import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { /** Check if a command should not be explained. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index be885397a7d40..168b5ab0316d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -22,6 +22,7 @@ import java.util import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -29,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} -import org.apache.spark.{SparkEnv, TaskContext} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 4f8524f4b967c..fff72872c13b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{IntegerType, StructType, StringType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration @@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer( } } - private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { - val bucketIdIndex = partitionColumns.length - if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { - false - } else { - var i = partitionColumns.length - 1 - while (i >= 0) { - val dt = partitionColumns(i).dataType - if (key1.get(i, dt) != key2.get(i, dt)) return false - i -= 1 - } - true - } - } - - private def sortBasedWrite( - sorter: UnsafeKVExternalSorter, - iterator: Iterator[InternalRow], - getSortingKey: UnsafeProjection, - getOutputRow: UnsafeProjection, - getPartitionString: UnsafeProjection, - outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { - (key1, key2) => key1 != key2 - } else { - (key1, key2) => key1 == null || !sameBucket(key1, key2) - } - - val sortedIterator = sorter.sortedIterator() - var currentKey: UnsafeRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (needNewWriter(currentKey, sortedIterator.getKey)) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey, getPartitionString) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } - } - /** * Open and returns a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the @@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer( } def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] executorSideSetup(taskContext) - var outputWritersCleared = false - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val getSortingKey = - UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) - - val sortingKeySchema = if (bucketSpec.isEmpty) { - StructType.fromAttributes(partitionColumns) - } else { // If it's bucketed, we should also consider bucket id as part of the key. - val fields = StructType.fromAttributes(partitionColumns) - .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) - StructType(fields) - } + val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns + + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) // Returns the data columns to be written given an input row val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) @@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // If there is no sorting columns, we set sorter to null and try the hash-based writing first, - // and fill the sorter if there are too many writers and we need to fall back on sorting. - // If there are sorting columns, then we have to sort the data anyway, and no need to try the - // hash-based writing first. - var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { - new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) + // Sorts the data before write, so that we only need one writer at the same time. + // TODO: inject a local sort operator in planning. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity } else { - null + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) } - while (iterator.hasNext && sorter == null) { - val inputRow = iterator.next() - // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key. - val currentKey = getSortingKey(inputRow) - var currentWriter = outputWriters.get(currentKey) - - if (currentWriter == null) { - if (outputWriters.size < maxOpenFiles) { + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + currentWriter = newOutputWriter(currentKey, getPartitionString) - outputWriters.put(currentKey.copy(), currentWriter) - currentWriter.writeInternal(getOutputRow(inputRow)) - } else { - logInfo(s"Maximum partitions reached, falling back on sorting.") - sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - sorter.insertKV(currentKey, getOutputRow(inputRow)) } - } else { - currentWriter.writeInternal(getOutputRow(inputRow)) - } - } - // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort, or there are sorting columns and we need to sort the whole data set. - if (sorter != null) { - sortBasedWrite( - sorter, - iterator, - getSortingKey, - getOutputRow, - getPartitionString, - outputWriters) + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } } commitTask() @@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer( abortTask() throw new SparkException("Task failed while writing rows.", cause) } - - def clearOutputWriters(): Unit = { - if (!outputWritersCleared) { - outputWriters.asScala.values.foreach(_.close()) - outputWriters.clear() - outputWritersCleared = true - } - } - - def commitTask(): Unit = { - try { - clearOutputWriters() - super.commitTask() - } catch { - case cause: Throwable => - throw new RuntimeException("Failed to commit task", cause) - } - } - - def abortTask(): Unit = { - try { - clearOutputWriters() - } finally { - super.abortTask() - } - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 82287c8967134..9976829638d70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.mapreduce.TaskAttemptContext + import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.sources.{HadoopFsRelation, HadoopFsRelationProvider, OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 2e3fe3da15389..b2f5c1e96421d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -90,7 +90,7 @@ object JacksonParser { DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => - parser.getLongValue * 1000L + parser.getLongValue * 1000000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 4b375de05e9e3..7754edc803d10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -44,9 +44,9 @@ import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -147,6 +147,12 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) + // If this relation is converted from a Hive metastore table, this method returns the name of the + // original Hive metastore table. + private[sql] def metastoreTableName: Option[TableIdentifier] = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) + } + private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index d484403d1c641..1c773e69275db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c35f33132f602..9f3607369c30f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -162,7 +162,6 @@ trait HadoopFsRelationProvider { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation - // TODO: expose bucket API to users. private[sql] def createRelation( sqlContext: SQLContext, paths: Array[String], @@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable { dataSchema: StructType, context: TaskAttemptContext): OutputWriter - // TODO: expose bucket API to users. private[sql] def newInstance( path: String, bucketId: Option[Int], @@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ - // TODO: expose bucket API to users. private[sql] def bucketSpec: Option[BucketSpec] = None private class FileStatusCache { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ade1391ecd74a..983dfbdedeefe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + + // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake + checkAnswer( + sqlContext.range(2).limit(2147483638), + Row(0) :: Row(1) :: Nil + ) } test("except") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bd987ae1bb03a..5de0979606b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } } - - test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") { - val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 ) - rdd.toDF("key").registerTempTable("spark12340") - checkAnswer( - sql("select key from spark12340 limit 2147483638"), - Row(1) :: Row(2) :: Row(3) :: Nil - ) - assert(rdd.take(2147483638).size === 3) - assert(rdd.takeAsync(2147483638).get.size === 3) - } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b3b6b7df0c1d1..4ab148065a476 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -83,9 +83,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber * 1000L)), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), @@ -1465,4 +1465,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("Casting long as timestamp") { + withTempTable("jsonTable") { + val schema = (new StructType).add("ts", TimestampType) + val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + + jsonDF.registerTempTable("jsonTable") + + checkAnswer( + sql("select ts from jsonTable"), + Row(java.sql.Timestamp.valueOf("2016-01-02 03:04:05")) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index cb61f7eeca0de..a0836058d3c74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -205,6 +205,10 @@ private[json] trait TestJsonData { """{"b": [{"c": {}}]}""" :: """]""" :: Nil) + def timestampAsLong: RDD[String] = + sqlContext.sparkContext.parallelize( + """{"ts":1451732645}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cab6abde6da23..ae95b50e1ee76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.util.{Benchmark, Utils} -import org.apache.spark.{SparkConf, SparkContext} /** * Benchmark to measure parquet read performance. diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bd1a52e5f3303..afd2f611580fc 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) + super.afterAll() } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. */ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 98bbdf0653c2a..bad3ca6da231f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bf3fe12d5c5d2..d1b1c0d8d8bc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -24,11 +24,12 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions._ @@ -38,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkQl import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} import org.apache.spark.sql.types._ import org.apache.spark.sql.AnalysisException @@ -668,7 +669,8 @@ private[hive] object HiveQl extends SparkQl with Logging { Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + HiveGenericUDTF( + functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) case other => super.nodeToGenerator(node) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 0000000000000..61e3f183bb42d --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation + +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that this + * all resolved logical plan are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). + */ +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + + def toSQL: Option[String] = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val maybeSQL = try { + toSQL(canonicalizedPlan) + } catch { case cause: UnsupportedOperationException => + logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") + None + } + + if (maybeSQL.isDefined) { + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + |# Built SQL query string: + |${maybeSQL.get} + """.stripMargin) + } else { + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + } + + maybeSQL + } + + private def projectToSQL( + projectList: Seq[NamedExpression], + child: LogicalPlan, + isDistinct: Boolean): Option[String] = { + for { + childSQL <- toSQL(child) + listSQL = projectList.map(_.sql).mkString(", ") + maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + distinct = if (isDistinct) " DISTINCT " else " " + } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" + } + + private def aggregateToSQL( + groupingExprs: Seq[Expression], + aggExprs: Seq[Expression], + child: LogicalPlan): Option[String] = { + val aggSQL = aggExprs.map(_.sql).mkString(", ") + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + val maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + + toSQL(child).map { childSQL => + s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" + } + } + + private def toSQL(node: LogicalPlan): Option[String] = node match { + case Distinct(Project(list, child)) => + projectToSQL(list, child, isDistinct = true) + + case Project(list, child) => + projectToSQL(list, child, isDistinct = false) + + case Aggregate(groupingExprs, aggExprs, child) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Limit(limit, child) => + for { + childSQL <- toSQL(child) + limitSQL = limit.sql + } yield s"$childSQL LIMIT $limitSQL" + + case Filter(condition, child) => + for { + childSQL <- toSQL(child) + whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + conditionSQL = condition.sql + } yield s"$childSQL $whereOrHaving $conditionSQL" + + case Union(left, right) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + } yield s"$leftSQL UNION ALL $rightSQL" + + // ParquetRelation converted from Hive metastore table + case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is + // that, the metastore database name and table name are not always propagated to converted + // `ParquetRelation` instances via data source options. Here we use subquery alias as a + // workaround. + Some(s"`$alias`") + + case Subquery(alias, child) => + toSQL(child).map(childSQL => s"($childSQL) AS $alias") + + case Join(left, right, joinType, condition) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + joinTypeSQL = joinType.sql + conditionSQL = condition.map(" ON " + _.sql).getOrElse("") + } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" + + case MetastoreRelation(database, table, alias) => + val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") + Some(s"`$database`.`$table`$aliasSQL") + + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL CLUSTER BY $partitionExprsSQL" + + case Sort(orders, global, child) => + for { + childSQL <- toSQL(child) + ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + orderOrSort = if (global) "ORDER" else "SORT" + } yield s"$childSQL $orderOrSort BY $ordersSQL" + + case RepartitionByExpression(partitionExprs, child, _) => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + + case OneRowRelation => + Some("") + + case _ => None + } + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Canonicalizer", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. + ProjectCollapsing, + + // Used to handle other auxiliary `Project`s added by analyzer (e.g. + // `ResolveAggregateFunctions` rule) + RecoverScopingInfo + ) + ) + + object RecoverScopingInfo extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + // This branch handles aggregate functions within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... + // +- MetastoreRelation default, src, None + case plan @ Project(_, Filter(_, _: Aggregate)) => + wrapChildWithSubquery(plan) + + case plan @ Project(_, + _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + ) => plan + + case plan: Project => + wrapChildWithSubquery(plan) + } + + def wrapChildWithSubquery(project: Project): Project = project match { + case Project(projectList, child) => + val alias = SQLBuilder.newSubqueryName + val childAttributes = child.outputSet + val aliasedProjectList = projectList.map(_.transform { + case a: Attribute if childAttributes.contains(a) => + a.withQualifiers(alias :: Nil) + }.asInstanceOf[NamedExpression]) + + Project(aliasedProjectList, Subquery(alias, child)) + } + } + } +} + +object SQLBuilder { + private val nextSubqueryId = new AtomicLong(0) + + private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index b1a6d0ab7df3c..56cab1aee89df 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} @@ -32,15 +31,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.Obje import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.{analysis, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ @@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry( try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. udtf } else { @@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry( } } -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. */ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false override val dataType: DataType = inspectorToDataType(returnInspector) -} + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d26cb48479066..033746d42f557 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -37,8 +37,8 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.hive.client.ClientWrapper +import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} // SPARK-3729: Test key required to check for initialization errors with config. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index a2d283622ca52..e72a18a716b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,8 +21,8 @@ import scala.util.Try import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala new file mode 100644 index 0000000000000..3a6eb57add4e3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLBuilderSuite extends SQLBuilderTest { + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") + checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), + "TIMESTAMP('2016-01-01 00:00:00.0')") + // TODO tests for decimals + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala new file mode 100644 index 0000000000000..9a8a9c51183da --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sqlContext.range(10).write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + val convertedSQL = new SQLBuilder(df).toSQL + + if (convertedSQL.isEmpty) { + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + val sqlString = convertedSQL.get + try { + checkAnswer(sql(sqlString), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$sqlString + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("in") { + checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + } + + // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule + // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into + // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query + // execution since these aliases have different expression ID. But this introduces name collision + // when converting resolved plans back to SQL query strings as expression IDs are stripped. + ignore("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM t0") + } + + test("cluster by") { + checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + } + + // TODO Enable this + // Query plans transformed by DistinctAggregationRewriter are not recognized yet + ignore("distinct and non-distinct aggregation") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala new file mode 100644 index 0000000000000..a5e209ac9db3b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton + +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + val actualSQL = maybeSQL.get + + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d7e8ebc8d312f..fd3339a66bec0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} import org.apache.spark.sql.hive.test.TestHive /** @@ -130,6 +131,28 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } + /** Used for testing [[SQLBuilder]] */ + private var numConvertibleQueries: Int = 0 + private var numTotalQueries: Int = 0 + + override protected def afterAll(): Unit = { + logInfo({ + val percentage = if (numTotalQueries > 0) { + numConvertibleQueries.toDouble / numTotalQueries * 100 + } else { + 0D + } + + s"""SQLBuiler statistics: + |- Total query number: $numTotalQueries + |- Number of convertible queries: $numConvertibleQueries + |- Percentage of convertible queries: $percentage% + """.stripMargin + }) + + super.afterAll() + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -372,8 +395,49 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + numTotalQueries += 1 + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} + """.stripMargin.trim) + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + |}}} + """.stripMargin.trim) + originalQuery + } + } + } + + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fa99289b41971..4659d745fe78b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 579da0291f291..7f1745705aaaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest} class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b186d297610e2..86f01d2168729 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -27,8 +27,8 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.Utils import org.apache.spark.streaming.scheduler.JobGenerator +import org.apache.spark.util.Utils private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 4e5baebaae04b..4ccc905b275d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -25,7 +25,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} +import org.apache.spark.serializer.{KryoInputObjectInputBridge, KryoOutputObjectOutputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index ea32bbf95ce59..da0430e263b5f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.rdd.MapWithStateRDDRecord - import scala.collection.{immutable, mutable, Map} import scala.reflect.ClassTag import scala.util.Random import com.esotericsoftware.kryo.{Kryo, KryoSerializable} -import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer._ +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} class StateMapSuite extends SparkFunSuite {