diff --git a/.gitignore b/.gitignore
index 07524bc429e92..8ecf536e79a5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,7 +60,6 @@ dev/create-release/*final
spark-*-bin-*.tgz
unit-tests.log
/lib/
-ec2/lib/
rat-results.txt
scalastyle.txt
scalastyle-output.xml
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index beacc39500aaa..34be7f0ebd752 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -130,6 +130,7 @@ exportMethods("%in%",
"count",
"countDistinct",
"crc32",
+ "hash",
"cume_dist",
"date_add",
"date_format",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index df36bc869acb4..9bb7876b384ce 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -340,6 +340,26 @@ setMethod("crc32",
column(jc)
})
+#' hash
+#'
+#' Calculates the hash code of given columns, and returns the result as an int column.
+#'
+#' @rdname hash
+#' @name hash
+#' @family misc_funcs
+#' @export
+#' @examples \dontrun{hash(df$c)}
+setMethod("hash",
+ signature(x = "Column"),
+ function(x, ...) {
+ jcols <- lapply(list(x, ...), function (x) {
+ stopifnot(class(x) == "Column")
+ x@jc
+ })
+ jc <- callJStatic("org.apache.spark.sql.functions", "hash", jcols)
+ column(jc)
+ })
+
#' dayofmonth
#'
#' Extracts the day of the month as an integer from a given date/timestamp/string.
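[editor's note] For context on the change above: the new SparkR `hash` wrapper simply forwards its columns to the JVM-side `org.apache.spark.sql.functions.hash` via `callJStatic`. Below is a minimal, illustrative Scala sketch of the equivalent JVM call, assuming a recent Spark build with `SparkSession` and a toy DataFrame; none of these names are part of this patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, hash}

object HashColumnExample {
  def main(args: Array[String]): Unit = {
    // Local session and toy data, purely for illustration.
    val spark = SparkSession.builder().master("local[2]").appName("hash-example").getOrCreate()
    import spark.implicits._

    val df = Seq(("alice", 1), ("bob", 2)).toDF("name", "n")

    // functions.hash accepts one or more Columns and yields a single int column,
    // which is what the R wrapper reaches through callJStatic above.
    df.select(col("name"), hash(col("name"), col("n")).as("h")).show()

    spark.stop()
  }
}
```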
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ba6861709754d..5ba68e3a4f378 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -736,6 +736,10 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct")
#' @export
setGeneric("crc32", function(x) { standardGeneric("crc32") })
+#' @rdname hash
+#' @export
+setGeneric("hash", function(x, ...) { standardGeneric("hash") })
+
#' @rdname cume_dist
#' @export
setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index eaf60beda3473..97625b94a0e23 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -922,7 +922,7 @@ test_that("column functions", {
c <- column("a")
c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c)
c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c)
- c3 <- cosh(c) + count(c) + crc32(c) + exp(c)
+ c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c)
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 94719a4572ef6..7de9df1e489fb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -77,7 +77,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
This implementation is non-blocking, asynchronously handling the
results of each job and triggering the next job using callbacks on futures.
*/
- def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+ def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] =
if (results.size >= num || partsScanned >= totalParts) {
Future.successful(results.toSeq)
} else {
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
}
val left = num - results.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
@@ -109,13 +109,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
- job.flatMap {_ =>
+ job.flatMap { _ =>
buf.foreach(results ++= _.take(num - results.size))
continue(partsScanned + p.size)
}
}
- new ComplexFutureAction[Seq[T]](continue(0L)(_))
+ new ComplexFutureAction[Seq[T]](continue(0)(_))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e25657cc109be..de7102f5b6245 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag](
} else {
val buf = new ArrayBuffer[T]
val totalParts = this.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -1209,7 +1209,7 @@ abstract class RDD[T: ClassTag](
}
val left = num - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(num - buf.size))
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 24acbed4d7258..ef2ed445005d3 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -482,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
assert(nums.take(501) === (1 to 501).toArray)
assert(nums.take(999) === (1 to 999).toArray)
assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.parallelize(1 to 2, 2)
+ assert(nums.take(2147483638).size === 2)
+ assert(nums.takeAsync(2147483638).get.size === 2)
}
test("top with predefined ordering") {
diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh
index b0a3374becc6a..d404939d1caee 100755
--- a/dev/create-release/release-tag.sh
+++ b/dev/create-release/release-tag.sh
@@ -64,9 +64,6 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG"
echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH"
git tag $RELEASE_TAG
-# TODO: It would be nice to do some verifications here
-# i.e. check whether ec2 scripts have the new version
-
# Create next version
$MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs
git commit -a -m "Preparing development version $NEXT_VERSION"
diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
index 7f152b7f53559..5d0ac16b3b0a1 100755
--- a/dev/create-release/releaseutils.py
+++ b/dev/create-release/releaseutils.py
@@ -159,7 +159,6 @@ def get_commits(tag):
"build": CORE_COMPONENT,
"deploy": CORE_COMPONENT,
"documentation": CORE_COMPONENT,
- "ec2": "EC2",
"examples": CORE_COMPONENT,
"graphx": "GraphX",
"input/output": CORE_COMPONENT,
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index e4373f79f7922..cd3ff293502ae 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -84,13 +84,13 @@ hadoop-yarn-server-web-proxy-2.2.0.jar
httpclient-4.3.2.jar
httpcore-4.3.2.jar
ivy-2.4.0.jar
-jackson-annotations-2.4.4.jar
-jackson-core-2.4.4.jar
+jackson-annotations-2.5.3.jar
+jackson-core-2.5.3.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.4.4.jar
+jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.4.4.jar
+jackson-module-scala_2.10-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
jansi-1.4.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 7478181406d07..0985089ccea61 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.3.0.jar
httpclient-4.3.2.jar
httpcore-4.3.2.jar
ivy-2.4.0.jar
-jackson-annotations-2.4.4.jar
-jackson-core-2.4.4.jar
+jackson-annotations-2.5.3.jar
+jackson-core-2.5.3.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.4.4.jar
+jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.4.4.jar
+jackson-module-scala_2.10-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
jansi-1.4.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index faffb8bf398a5..50f062601c02b 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -79,13 +79,13 @@ hadoop-yarn-server-web-proxy-2.4.0.jar
httpclient-4.3.2.jar
httpcore-4.3.2.jar
ivy-2.4.0.jar
-jackson-annotations-2.4.4.jar
-jackson-core-2.4.4.jar
+jackson-annotations-2.5.3.jar
+jackson-core-2.5.3.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.4.4.jar
+jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.4.4.jar
+jackson-module-scala_2.10-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
jansi-1.4.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index e703c7acd3876..2b6ca983ad65e 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -85,13 +85,13 @@ htrace-core-3.0.4.jar
httpclient-4.3.2.jar
httpcore-4.3.2.jar
ivy-2.4.0.jar
-jackson-annotations-2.4.4.jar
-jackson-core-2.4.4.jar
+jackson-annotations-2.5.3.jar
+jackson-core-2.5.3.jar
jackson-core-asl-1.9.13.jar
-jackson-databind-2.4.4.jar
+jackson-databind-2.5.3.jar
jackson-jaxrs-1.9.13.jar
jackson-mapper-asl-1.9.13.jar
-jackson-module-scala_2.10-2.4.4.jar
+jackson-module-scala_2.10-2.5.3.jar
jackson-xc-1.9.13.jar
janino-2.7.8.jar
jansi-1.4.jar
diff --git a/dev/lint-python b/dev/lint-python
index 0b97213ae3dff..1765a07d2f22b 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -19,7 +19,7 @@
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
-PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport"
+PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport"
PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py"
PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt"
PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt"
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 47cd600bd18a4..1fc6596164124 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -406,15 +406,6 @@ def contains_file(self, filename):
should_run_build_tests=True
)
-ec2 = Module(
- name="ec2",
- dependencies=[],
- source_file_regexes=[
- "ec2/",
- ]
-)
-
-
yarn = Module(
name="yarn",
dependencies=[],
diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh
index 424ce6ad7663c..def87aa4087e3 100755
--- a/dev/test-dependencies.sh
+++ b/dev/test-dependencies.sh
@@ -70,19 +70,10 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de
# Generate manifests for each Hadoop profile:
for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do
echo "Performing Maven install for $HADOOP_PROFILE"
- $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar install:install -q \
- -pl '!assembly' \
- -pl '!examples' \
- -pl '!external/flume-assembly' \
- -pl '!external/kafka-assembly' \
- -pl '!external/twitter' \
- -pl '!external/flume' \
- -pl '!external/mqtt' \
- -pl '!external/mqtt-assembly' \
- -pl '!external/zeromq' \
- -pl '!external/kafka' \
- -pl '!tags' \
- -DskipTests
+ $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install -q
+
+ echo "Performing Maven validate for $HADOOP_PROFILE"
+ $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q
echo "Generating dependency manifest for $HADOOP_PROFILE"
mkdir -p dev/pr-deps
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 62d75eff71057..d493f62f0e578 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -98,8 +98,6 @@
Spark Standalone
Mesos
YARN
-                            <li class="divider"></li>
-                            <li><a href="ec2-scripts.html">Amazon EC2</a></li>
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index faaf154d243f5..2810112f5294e 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -53,8 +53,6 @@ The system currently supports three cluster managers:
and service applications.
* [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2.
-In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone
-cluster on Amazon EC2.
# Submitting Applications
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
deleted file mode 100644
index 7f60f82b966fe..0000000000000
--- a/docs/ec2-scripts.md
+++ /dev/null
@@ -1,192 +0,0 @@
----
-layout: global
-title: Running Spark on EC2
----
-
-The `spark-ec2` script, located in Spark's `ec2` directory, allows you
-to launch, manage and shut down Spark clusters on Amazon EC2. It automatically
-sets up Spark and HDFS on the cluster for you. This guide describes
-how to use `spark-ec2` to launch clusters, how to run jobs on them, and how
-to shut them down. It assumes you've already signed up for an EC2 account
-on the [Amazon Web Services site](http://aws.amazon.com/).
-
-`spark-ec2` is designed to manage multiple named clusters. You can
-launch a new cluster (telling the script its size and giving it a name),
-shutdown an existing cluster, or log into a cluster. Each cluster is
-identified by placing its machines into EC2 security groups whose names
-are derived from the name of the cluster. For example, a cluster named
-`test` will contain a master node in a security group called
-`test-master`, and a number of slave nodes in a security group called
-`test-slaves`. The `spark-ec2` script will create these security groups
-for you based on the cluster name you request. You can also use them to
-identify machines belonging to each cluster in the Amazon EC2 Console.
-
-
-# Before You Start
-
-- Create an Amazon EC2 key pair for yourself. This can be done by
- logging into your Amazon Web Services account through the [AWS
- console](http://aws.amazon.com/console/), clicking Key Pairs on the
- left sidebar, and creating and downloading a key. Make sure that you
- set the permissions for the private key file to `600` (i.e. only you
- can read and write it) so that `ssh` will work.
-- Whenever you want to use the `spark-ec2` script, set the environment
- variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to your
- Amazon EC2 access key ID and secret access key. These can be
- obtained from the [AWS homepage](http://aws.amazon.com/) by clicking
- Account \> Security Credentials \> Access Credentials.
-
-# Launching a Cluster
-
-- Go into the `ec2` directory in the release of Spark you downloaded.
-- Run
- `./spark-ec2 -k <keypair> -i <key-file> -s <num-slaves> launch <cluster-name>`,
- where `<keypair>` is the name of your EC2 key pair (that you gave it
- when you created it), `<key-file>` is the private key file for your
- key pair, `<num-slaves>` is the number of slave nodes to launch (try
- 1 at first), and `<cluster-name>` is the name to give to your
- cluster.
-
- For example:
-
- ```bash
- export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU
-export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123
-./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster
- ```
-
-- After everything launches, check that the cluster scheduler is up and sees
- all the slaves by going to its web UI, which will be printed at the end of
- the script (typically `http://<master-hostname>:8080`).
-
-You can also run `./spark-ec2 --help` to see more usage options. The
-following options are worth pointing out:
-
-- `--instance-type=<instance-type>` can be used to specify an EC2
-instance type to use. For now, the script only supports 64-bit instance
-types, and the default type is `m1.large` (which has 2 cores and 7.5 GB
-RAM). Refer to the Amazon pages about [EC2 instance
-types](http://aws.amazon.com/ec2/instance-types) and [EC2
-pricing](http://aws.amazon.com/ec2/#pricing) for information about other
-instance types.
-- `--region=<ec2-region>` specifies an EC2 region in which to launch
-instances. The default region is `us-east-1`.
-- `--zone=<ec2-zone>` can be used to specify an EC2 availability zone
-to launch instances in. Sometimes, you will get an error because there
-is not enough capacity in one zone, and you should try to launch in
-another.
-- `--ebs-vol-size=<GB>` will attach an EBS volume with a given amount
- of space to each node so that you can have a persistent HDFS cluster
- on your nodes across cluster restarts (see below).
-- `--spot-price=<price>` will launch the worker nodes as
- [Spot Instances](http://aws.amazon.com/ec2/spot-instances/),
- bidding for the given maximum price (in dollars).
-- `--spark-version=<version>` will pre-load the cluster with the
- specified version of Spark. The `<version>` can be a version number
- (e.g. "0.7.3") or a specific git hash. By default, a recent
- version will be used.
-- `--spark-git-repo=<repository url>` will let you run a custom version of
- Spark that is built from the given git repository. By default, the
- [Apache Github mirror](https://github.com/apache/spark) will be used.
- When using a custom Spark version, `--spark-version` must be set to git
- commit hash, such as 317e114, instead of a version number.
-- If one of your launches fails due to e.g. not having the right
-permissions on your private key file, you can run `launch` with the
-`--resume` option to restart the setup process on an existing cluster.
-
-# Launching a Cluster in a VPC
-
-- Run
- `./spark-ec2 -k <keypair> -i <key-file> -s <num-slaves> --vpc-id=<vpc-id> --subnet-id=<subnet-id> launch <cluster-name>`,
- where `<keypair>` is the name of your EC2 key pair (that you gave it
- when you created it), `<key-file>` is the private key file for your
- key pair, `<num-slaves>` is the number of slave nodes to launch (try
- 1 at first), `<vpc-id>` is the name of your VPC, `<subnet-id>` is the
- name of your subnet, and `<cluster-name>` is the name to give to your
- cluster.
-
- For example:
-
- ```bash
- export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU
-export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123
-./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --vpc-id=vpc-a28d24c7 --subnet-id=subnet-4eb27b39 --spark-version=1.1.0 launch my-spark-cluster
- ```
-
-# Running Applications
-
-- Go into the `ec2` directory in the release of Spark you downloaded.
-- Run `./spark-ec2 -k <keypair> -i <key-file> login <cluster-name>` to
- SSH into the cluster, where `<keypair>` and `<key-file>` are as
- above. (This is just for convenience; you could also use
- the EC2 console.)
-- To deploy code or data within your cluster, you can log in and use the
- provided script `~/spark-ec2/copy-dir`, which,
- given a directory path, RSYNCs it to the same location on all the slaves.
-- If your application needs to access large datasets, the fastest way to do
- that is to load them from Amazon S3 or an Amazon EBS device into an
- instance of the Hadoop Distributed File System (HDFS) on your nodes.
- The `spark-ec2` script already sets up a HDFS instance for you. It's
- installed in `/root/ephemeral-hdfs`, and can be accessed using the
- `bin/hadoop` script in that directory. Note that the data in this
- HDFS goes away when you stop and restart a machine.
-- There is also a *persistent HDFS* instance in
- `/root/persistent-hdfs` that will keep data across cluster restarts.
- Typically each node has relatively little space of persistent data
- (about 3 GB), but you can use the `--ebs-vol-size` option to
- `spark-ec2` to attach a persistent EBS volume to each node for
- storing the persistent HDFS.
-- Finally, if you get errors while running your application, look at the slave's logs
- for that application inside of the scheduler work directory (/root/spark/work). You can
- also view the status of the cluster using the web UI: `http://<master-hostname>:8080`.
-
-# Configuration
-
-You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such
-as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to
-do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master,
-then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers.
-
-The [configuration guide](configuration.html) describes the available configuration options.
-
-# Terminating a Cluster
-
-***Note that there is no way to recover data on EC2 nodes after shutting
-them down! Make sure you have copied everything important off the nodes
-before stopping them.***
-
-- Go into the `ec2` directory in the release of Spark you downloaded.
-- Run `./spark-ec2 destroy <cluster-name>`.
-
-# Pausing and Restarting Clusters
-
-The `spark-ec2` script also supports pausing a cluster. In this case,
-the VMs are stopped but not terminated, so they
-***lose all data on ephemeral disks*** but keep the data in their
-root partitions and their `persistent-hdfs`. Stopped machines will not
-cost you any EC2 cycles, but ***will*** continue to cost money for EBS
-storage.
-
-- To stop one of your clusters, go into the `ec2` directory and run
-`./spark-ec2 --region=<ec2-region> stop <cluster-name>`.
-- To restart it later, run
-`./spark-ec2 -i <key-file> --region=<ec2-region> start <cluster-name>`.
-- To ultimately destroy the cluster and stop consuming EBS space, run
-`./spark-ec2 --region=<ec2-region> destroy <cluster-name>` as described in the previous
-section.
-
-# Limitations
-
-- Support for "cluster compute" nodes is limited -- there's no way to specify a
- locality group. However, you can launch slave nodes in your
- `<clusterName>-slaves` group manually and then use `spark-ec2 launch
- --resume` to start a cluster with them.
-
-If you have a patch or suggestion for one of these limitations, feel free to
-[contribute](contributing-to-spark.html) it!
-
-# Accessing Data in S3
-
-Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n://<bucket>/path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3).
-
-In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory.
diff --git a/docs/index.md b/docs/index.md
index ae26f97c86c21..9dfc52a2bdc9b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -64,7 +64,7 @@ To run Spark interactively in a R interpreter, use `bin/sparkR`:
./bin/sparkR --master local[2]
Example applications are also provided in R. For example,
-
+
./bin/spark-submit examples/src/main/r/dataframe.R
# Launching on a Cluster
@@ -73,7 +73,6 @@ The Spark [cluster mode overview](cluster-overview.html) explains the key concep
Spark can run both by itself, or over several existing cluster managers. It currently provides several
options for deployment:
-* [Amazon EC2](ec2-scripts.html): our EC2 scripts let you launch a cluster in about 5 minutes
* [Standalone Deploy Mode](spark-standalone.html): simplest way to deploy Spark on a private cluster
* [Apache Mesos](running-on-mesos.html)
* [Hadoop YARN](running-on-yarn.html)
@@ -103,7 +102,7 @@ options for deployment:
* [Cluster Overview](cluster-overview.html): overview of concepts and components when running on a cluster
* [Submitting Applications](submitting-applications.html): packaging and deploying applications
* Deployment modes:
- * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes
+ * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes
* [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager
* [Mesos](running-on-mesos.html): deploy a private cluster using
[Apache Mesos](http://mesos.apache.org)
diff --git a/ec2/README b/ec2/README
deleted file mode 100644
index 72434f24bf98d..0000000000000
--- a/ec2/README
+++ /dev/null
@@ -1,4 +0,0 @@
-This folder contains a script, spark-ec2, for launching Spark clusters on
-Amazon EC2. Usage instructions are available online at:
-
-http://spark.apache.org/docs/latest/ec2-scripts.html
diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
deleted file mode 100644
index 4f3e8da809f7f..0000000000000
--- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/usr/bin/env bash
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# These variables are automatically filled in by the spark-ec2 script.
-export MASTERS="{{master_list}}"
-export SLAVES="{{slave_list}}"
-export HDFS_DATA_DIRS="{{hdfs_data_dirs}}"
-export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}"
-export SPARK_LOCAL_DIRS="{{spark_local_dirs}}"
-export MODULES="{{modules}}"
-export SPARK_VERSION="{{spark_version}}"
-export TACHYON_VERSION="{{tachyon_version}}"
-export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}"
-export SWAP_MB="{{swap}}"
-export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}"
-export SPARK_MASTER_OPTS="{{spark_master_opts}}"
-export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}"
-export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}"
diff --git a/ec2/spark-ec2 b/ec2/spark-ec2
deleted file mode 100755
index 26e7d22655694..0000000000000
--- a/ec2/spark-ec2
+++ /dev/null
@@ -1,25 +0,0 @@
-#!/bin/sh
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Preserve the user's CWD so that relative paths are passed correctly to
-#+ the underlying Python script.
-SPARK_EC2_DIR="$(dirname "$0")"
-
-python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
deleted file mode 100755
index 19d5980560fef..0000000000000
--- a/ec2/spark_ec2.py
+++ /dev/null
@@ -1,1530 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from __future__ import division, print_function, with_statement
-
-import codecs
-import hashlib
-import itertools
-import logging
-import os
-import os.path
-import pipes
-import random
-import shutil
-import string
-from stat import S_IRUSR
-import subprocess
-import sys
-import tarfile
-import tempfile
-import textwrap
-import time
-import warnings
-from datetime import datetime
-from optparse import OptionParser
-from sys import stderr
-
-if sys.version < "3":
- from urllib2 import urlopen, Request, HTTPError
-else:
- from urllib.request import urlopen, Request
- from urllib.error import HTTPError
- raw_input = input
- xrange = range
-
-SPARK_EC2_VERSION = "1.6.0"
-SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
-
-VALID_SPARK_VERSIONS = set([
- "0.7.3",
- "0.8.0",
- "0.8.1",
- "0.9.0",
- "0.9.1",
- "0.9.2",
- "1.0.0",
- "1.0.1",
- "1.0.2",
- "1.1.0",
- "1.1.1",
- "1.2.0",
- "1.2.1",
- "1.3.0",
- "1.3.1",
- "1.4.0",
- "1.4.1",
- "1.5.0",
- "1.5.1",
- "1.5.2",
- "1.6.0",
-])
-
-SPARK_TACHYON_MAP = {
- "1.0.0": "0.4.1",
- "1.0.1": "0.4.1",
- "1.0.2": "0.4.1",
- "1.1.0": "0.5.0",
- "1.1.1": "0.5.0",
- "1.2.0": "0.5.0",
- "1.2.1": "0.5.0",
- "1.3.0": "0.5.0",
- "1.3.1": "0.5.0",
- "1.4.0": "0.6.4",
- "1.4.1": "0.6.4",
- "1.5.0": "0.7.1",
- "1.5.1": "0.7.1",
- "1.5.2": "0.7.1",
- "1.6.0": "0.8.2",
-}
-
-DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION
-DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark"
-
-# Default location to get the spark-ec2 scripts (and ami-list) from
-DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2"
-DEFAULT_SPARK_EC2_BRANCH = "branch-1.5"
-
-
-def setup_external_libs(libs):
- """
- Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH.
- """
- PYPI_URL_PREFIX = "https://pypi.python.org/packages/source"
- SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib")
-
- if not os.path.exists(SPARK_EC2_LIB_DIR):
- print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format(
- path=SPARK_EC2_LIB_DIR
- ))
- print("This should be a one-time operation.")
- os.mkdir(SPARK_EC2_LIB_DIR)
-
- for lib in libs:
- versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"])
- lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name)
-
- if not os.path.isdir(lib_dir):
- tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz")
- print(" - Downloading {lib}...".format(lib=lib["name"]))
- download_stream = urlopen(
- "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format(
- prefix=PYPI_URL_PREFIX,
- first_letter=lib["name"][:1],
- lib_name=lib["name"],
- lib_version=lib["version"]
- )
- )
- with open(tgz_file_path, "wb") as tgz_file:
- tgz_file.write(download_stream.read())
- with open(tgz_file_path, "rb") as tar:
- if hashlib.md5(tar.read()).hexdigest() != lib["md5"]:
- print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr)
- sys.exit(1)
- tar = tarfile.open(tgz_file_path)
- tar.extractall(path=SPARK_EC2_LIB_DIR)
- tar.close()
- os.remove(tgz_file_path)
- print(" - Finished downloading {lib}.".format(lib=lib["name"]))
- sys.path.insert(1, lib_dir)
-
-
-# Only PyPI libraries are supported.
-external_libs = [
- {
- "name": "boto",
- "version": "2.34.0",
- "md5": "5556223d2d0cc4d06dd4829e671dcecd"
- }
-]
-
-setup_external_libs(external_libs)
-
-import boto
-from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType
-from boto import ec2
-
-
-class UsageError(Exception):
- pass
-
-
-# Configure and parse our command-line arguments
-def parse_args():
- parser = OptionParser(
- prog="spark-ec2",
- version="%prog {v}".format(v=SPARK_EC2_VERSION),
- usage="%prog [options] \n\n"
- + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves")
-
- parser.add_option(
- "-s", "--slaves", type="int", default=1,
- help="Number of slaves to launch (default: %default)")
- parser.add_option(
- "-w", "--wait", type="int",
- help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start")
- parser.add_option(
- "-k", "--key-pair",
- help="Key pair to use on instances")
- parser.add_option(
- "-i", "--identity-file",
- help="SSH private key file to use for logging into instances")
- parser.add_option(
- "-p", "--profile", default=None,
- help="If you have multiple profiles (AWS or boto config), you can configure " +
- "additional, named profiles by using this option (default: %default)")
- parser.add_option(
- "-t", "--instance-type", default="m1.large",
- help="Type of instance to launch (default: %default). " +
- "WARNING: must be 64-bit; small instances won't work")
- parser.add_option(
- "-m", "--master-instance-type", default="",
- help="Master instance type (leave empty for same as instance-type)")
- parser.add_option(
- "-r", "--region", default="us-east-1",
- help="EC2 region used to launch instances in, or to find them in (default: %default)")
- parser.add_option(
- "-z", "--zone", default="",
- help="Availability zone to launch instances in, or 'all' to spread " +
- "slaves across multiple (an additional $0.01/Gb for bandwidth" +
- "between zones applies) (default: a single zone chosen at random)")
- parser.add_option(
- "-a", "--ami",
- help="Amazon Machine Image ID to use")
- parser.add_option(
- "-v", "--spark-version", default=DEFAULT_SPARK_VERSION,
- help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)")
- parser.add_option(
- "--spark-git-repo",
- default=DEFAULT_SPARK_GITHUB_REPO,
- help="Github repo from which to checkout supplied commit hash (default: %default)")
- parser.add_option(
- "--spark-ec2-git-repo",
- default=DEFAULT_SPARK_EC2_GITHUB_REPO,
- help="Github repo from which to checkout spark-ec2 (default: %default)")
- parser.add_option(
- "--spark-ec2-git-branch",
- default=DEFAULT_SPARK_EC2_BRANCH,
- help="Github repo branch of spark-ec2 to use (default: %default)")
- parser.add_option(
- "--deploy-root-dir",
- default=None,
- help="A directory to copy into / on the first master. " +
- "Must be absolute. Note that a trailing slash is handled as per rsync: " +
- "If you omit it, the last directory of the --deploy-root-dir path will be created " +
- "in / before copying its contents. If you append the trailing slash, " +
- "the directory is not created and its contents are copied directly into /. " +
- "(default: %default).")
- parser.add_option(
- "--hadoop-major-version", default="1",
- help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " +
- "(Hadoop 2.4.0) (default: %default)")
- parser.add_option(
- "-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
- help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
- "the given local address (for use with login)")
- parser.add_option(
- "--resume", action="store_true", default=False,
- help="Resume installation on a previously launched cluster " +
- "(for debugging)")
- parser.add_option(
- "--ebs-vol-size", metavar="SIZE", type="int", default=0,
- help="Size (in GB) of each EBS volume.")
- parser.add_option(
- "--ebs-vol-type", default="standard",
- help="EBS volume type (e.g. 'gp2', 'standard').")
- parser.add_option(
- "--ebs-vol-num", type="int", default=1,
- help="Number of EBS volumes to attach to each node as /vol[x]. " +
- "The volumes will be deleted when the instances terminate. " +
- "Only possible on EBS-backed AMIs. " +
- "EBS volumes are only attached if --ebs-vol-size > 0. " +
- "Only support up to 8 EBS volumes.")
- parser.add_option(
- "--placement-group", type="string", default=None,
- help="Which placement group to try and launch " +
- "instances into. Assumes placement group is already " +
- "created.")
- parser.add_option(
- "--swap", metavar="SWAP", type="int", default=1024,
- help="Swap space to set up per node, in MB (default: %default)")
- parser.add_option(
- "--spot-price", metavar="PRICE", type="float",
- help="If specified, launch slaves as spot instances with the given " +
- "maximum price (in dollars)")
- parser.add_option(
- "--ganglia", action="store_true", default=True,
- help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " +
- "the Ganglia page will be publicly accessible")
- parser.add_option(
- "--no-ganglia", action="store_false", dest="ganglia",
- help="Disable Ganglia monitoring for the cluster")
- parser.add_option(
- "-u", "--user", default="root",
- help="The SSH user you want to connect as (default: %default)")
- parser.add_option(
- "--delete-groups", action="store_true", default=False,
- help="When destroying a cluster, delete the security groups that were created")
- parser.add_option(
- "--use-existing-master", action="store_true", default=False,
- help="Launch fresh slaves, but use an existing stopped master if possible")
- parser.add_option(
- "--worker-instances", type="int", default=1,
- help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " +
- "is used as Hadoop major version (default: %default)")
- parser.add_option(
- "--master-opts", type="string", default="",
- help="Extra options to give to master through SPARK_MASTER_OPTS variable " +
- "(e.g -Dspark.worker.timeout=180)")
- parser.add_option(
- "--user-data", type="string", default="",
- help="Path to a user-data file (most AMIs interpret this as an initialization script)")
- parser.add_option(
- "--authorized-address", type="string", default="0.0.0.0/0",
- help="Address to authorize on created security groups (default: %default)")
- parser.add_option(
- "--additional-security-group", type="string", default="",
- help="Additional security group to place the machines in")
- parser.add_option(
- "--additional-tags", type="string", default="",
- help="Additional tags to set on the machines; tags are comma-separated, while name and " +
- "value are colon separated; ex: \"Task:MySparkProject,Env:production\"")
- parser.add_option(
- "--copy-aws-credentials", action="store_true", default=False,
- help="Add AWS credentials to hadoop configuration to allow Spark to access S3")
- parser.add_option(
- "--subnet-id", default=None,
- help="VPC subnet to launch instances in")
- parser.add_option(
- "--vpc-id", default=None,
- help="VPC to launch instances in")
- parser.add_option(
- "--private-ips", action="store_true", default=False,
- help="Use private IPs for instances rather than public if VPC/subnet " +
- "requires that.")
- parser.add_option(
- "--instance-initiated-shutdown-behavior", default="stop",
- choices=["stop", "terminate"],
- help="Whether instances should terminate when shut down or just stop")
- parser.add_option(
- "--instance-profile-name", default=None,
- help="IAM profile name to launch instances under")
-
- (opts, args) = parser.parse_args()
- if len(args) != 2:
- parser.print_help()
- sys.exit(1)
- (action, cluster_name) = args
-
- # Boto config check
- # http://boto.cloudhackers.com/en/latest/boto_config_tut.html
- home_dir = os.getenv('HOME')
- if home_dir is None or not os.path.isfile(home_dir + '/.boto'):
- if not os.path.isfile('/etc/boto.cfg'):
- # If there is no boto config, check aws credentials
- if not os.path.isfile(home_dir + '/.aws/credentials'):
- if os.getenv('AWS_ACCESS_KEY_ID') is None:
- print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set",
- file=stderr)
- sys.exit(1)
- if os.getenv('AWS_SECRET_ACCESS_KEY') is None:
- print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set",
- file=stderr)
- sys.exit(1)
- return (opts, action, cluster_name)
-
-
-# Get the EC2 security group of the given name, creating it if it doesn't exist
-def get_or_make_group(conn, name, vpc_id):
- groups = conn.get_all_security_groups()
- group = [g for g in groups if g.name == name]
- if len(group) > 0:
- return group[0]
- else:
- print("Creating security group " + name)
- return conn.create_security_group(name, "Spark EC2 group", vpc_id)
-
-
-def get_validate_spark_version(version, repo):
- if "." in version:
- version = version.replace("v", "")
- if version not in VALID_SPARK_VERSIONS:
- print("Don't know about Spark version: {v}".format(v=version), file=stderr)
- sys.exit(1)
- return version
- else:
- github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version)
- request = Request(github_commit_url)
- request.get_method = lambda: 'HEAD'
- try:
- response = urlopen(request)
- except HTTPError as e:
- print("Couldn't validate Spark commit: {url}".format(url=github_commit_url),
- file=stderr)
- print("Received HTTP response code of {code}.".format(code=e.code), file=stderr)
- sys.exit(1)
- return version
-
-
-# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/
-# Last Updated: 2015-06-19
-# For easy maintainability, please keep this manually-inputted dictionary sorted by key.
-EC2_INSTANCE_TYPES = {
- "c1.medium": "pvm",
- "c1.xlarge": "pvm",
- "c3.large": "pvm",
- "c3.xlarge": "pvm",
- "c3.2xlarge": "pvm",
- "c3.4xlarge": "pvm",
- "c3.8xlarge": "pvm",
- "c4.large": "hvm",
- "c4.xlarge": "hvm",
- "c4.2xlarge": "hvm",
- "c4.4xlarge": "hvm",
- "c4.8xlarge": "hvm",
- "cc1.4xlarge": "hvm",
- "cc2.8xlarge": "hvm",
- "cg1.4xlarge": "hvm",
- "cr1.8xlarge": "hvm",
- "d2.xlarge": "hvm",
- "d2.2xlarge": "hvm",
- "d2.4xlarge": "hvm",
- "d2.8xlarge": "hvm",
- "g2.2xlarge": "hvm",
- "g2.8xlarge": "hvm",
- "hi1.4xlarge": "pvm",
- "hs1.8xlarge": "pvm",
- "i2.xlarge": "hvm",
- "i2.2xlarge": "hvm",
- "i2.4xlarge": "hvm",
- "i2.8xlarge": "hvm",
- "m1.small": "pvm",
- "m1.medium": "pvm",
- "m1.large": "pvm",
- "m1.xlarge": "pvm",
- "m2.xlarge": "pvm",
- "m2.2xlarge": "pvm",
- "m2.4xlarge": "pvm",
- "m3.medium": "hvm",
- "m3.large": "hvm",
- "m3.xlarge": "hvm",
- "m3.2xlarge": "hvm",
- "m4.large": "hvm",
- "m4.xlarge": "hvm",
- "m4.2xlarge": "hvm",
- "m4.4xlarge": "hvm",
- "m4.10xlarge": "hvm",
- "r3.large": "hvm",
- "r3.xlarge": "hvm",
- "r3.2xlarge": "hvm",
- "r3.4xlarge": "hvm",
- "r3.8xlarge": "hvm",
- "t1.micro": "pvm",
- "t2.micro": "hvm",
- "t2.small": "hvm",
- "t2.medium": "hvm",
- "t2.large": "hvm",
-}
-
-
-def get_tachyon_version(spark_version):
- return SPARK_TACHYON_MAP.get(spark_version, "")
-
-
-# Attempt to resolve an appropriate AMI given the architecture and region of the request.
-def get_spark_ami(opts):
- if opts.instance_type in EC2_INSTANCE_TYPES:
- instance_type = EC2_INSTANCE_TYPES[opts.instance_type]
- else:
- instance_type = "pvm"
- print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr)
-
- # URL prefix from which to fetch AMI information
- ami_prefix = "{r}/{b}/ami-list".format(
- r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1),
- b=opts.spark_ec2_git_branch)
-
- ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type)
- reader = codecs.getreader("ascii")
- try:
- ami = reader(urlopen(ami_path)).read().strip()
- except:
- print("Could not resolve AMI at: " + ami_path, file=stderr)
- sys.exit(1)
-
- print("Spark AMI: " + ami)
- return ami
-
-
-# Launch a cluster of the given name, by setting up its security groups,
-# and then starting new instances in them.
-# Returns a tuple of EC2 reservation objects for the master and slaves
-# Fails if there already instances running in the cluster's groups.
-def launch_cluster(conn, opts, cluster_name):
- if opts.identity_file is None:
- print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr)
- sys.exit(1)
-
- if opts.key_pair is None:
- print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr)
- sys.exit(1)
-
- user_data_content = None
- if opts.user_data:
- with open(opts.user_data) as user_data_file:
- user_data_content = user_data_file.read()
-
- print("Setting up security groups...")
- master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id)
- slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id)
- authorized_address = opts.authorized_address
- if master_group.rules == []: # Group was just now created
- if opts.vpc_id is None:
- master_group.authorize(src_group=master_group)
- master_group.authorize(src_group=slave_group)
- else:
- master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1,
- src_group=master_group)
- master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535,
- src_group=master_group)
- master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535,
- src_group=master_group)
- master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1,
- src_group=slave_group)
- master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535,
- src_group=slave_group)
- master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535,
- src_group=slave_group)
- master_group.authorize('tcp', 22, 22, authorized_address)
- master_group.authorize('tcp', 8080, 8081, authorized_address)
- master_group.authorize('tcp', 18080, 18080, authorized_address)
- master_group.authorize('tcp', 19999, 19999, authorized_address)
- master_group.authorize('tcp', 50030, 50030, authorized_address)
- master_group.authorize('tcp', 50070, 50070, authorized_address)
- master_group.authorize('tcp', 60070, 60070, authorized_address)
- master_group.authorize('tcp', 4040, 4045, authorized_address)
- # Rstudio (GUI for R) needs port 8787 for web access
- master_group.authorize('tcp', 8787, 8787, authorized_address)
- # HDFS NFS gateway requires 111,2049,4242 for tcp & udp
- master_group.authorize('tcp', 111, 111, authorized_address)
- master_group.authorize('udp', 111, 111, authorized_address)
- master_group.authorize('tcp', 2049, 2049, authorized_address)
- master_group.authorize('udp', 2049, 2049, authorized_address)
- master_group.authorize('tcp', 4242, 4242, authorized_address)
- master_group.authorize('udp', 4242, 4242, authorized_address)
- # RM in YARN mode uses 8088
- master_group.authorize('tcp', 8088, 8088, authorized_address)
- if opts.ganglia:
- master_group.authorize('tcp', 5080, 5080, authorized_address)
- if slave_group.rules == []: # Group was just now created
- if opts.vpc_id is None:
- slave_group.authorize(src_group=master_group)
- slave_group.authorize(src_group=slave_group)
- else:
- slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1,
- src_group=master_group)
- slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535,
- src_group=master_group)
- slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535,
- src_group=master_group)
- slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1,
- src_group=slave_group)
- slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535,
- src_group=slave_group)
- slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535,
- src_group=slave_group)
- slave_group.authorize('tcp', 22, 22, authorized_address)
- slave_group.authorize('tcp', 8080, 8081, authorized_address)
- slave_group.authorize('tcp', 50060, 50060, authorized_address)
- slave_group.authorize('tcp', 50075, 50075, authorized_address)
- slave_group.authorize('tcp', 60060, 60060, authorized_address)
- slave_group.authorize('tcp', 60075, 60075, authorized_address)
-
- # Check if instances are already running in our groups
- existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
- die_on_error=False)
- if existing_slaves or (existing_masters and not opts.use_existing_master):
- print("ERROR: There are already instances running in group %s or %s" %
- (master_group.name, slave_group.name), file=stderr)
- sys.exit(1)
-
- # Figure out Spark AMI
- if opts.ami is None:
- opts.ami = get_spark_ami(opts)
-
- # we use group ids to work around https://github.com/boto/boto/issues/350
- additional_group_ids = []
- if opts.additional_security_group:
- additional_group_ids = [sg.id
- for sg in conn.get_all_security_groups()
- if opts.additional_security_group in (sg.name, sg.id)]
- print("Launching instances...")
-
- try:
- image = conn.get_all_images(image_ids=[opts.ami])[0]
- except:
- print("Could not find AMI " + opts.ami, file=stderr)
- sys.exit(1)
-
- # Create block device mapping so that we can add EBS volumes if asked to.
- # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz
- block_map = BlockDeviceMapping()
- if opts.ebs_vol_size > 0:
- for i in range(opts.ebs_vol_num):
- device = EBSBlockDeviceType()
- device.size = opts.ebs_vol_size
- device.volume_type = opts.ebs_vol_type
- device.delete_on_termination = True
- block_map["/dev/sd" + chr(ord('s') + i)] = device
-
- # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342).
- if opts.instance_type.startswith('m3.'):
- for i in range(get_num_disks(opts.instance_type)):
- dev = BlockDeviceType()
- dev.ephemeral_name = 'ephemeral%d' % i
- # The first ephemeral drive is /dev/sdb.
- name = '/dev/sd' + string.ascii_letters[i + 1]
- block_map[name] = dev
-
- # Launch slaves
- if opts.spot_price is not None:
- # Launch spot instances with the requested price
- print("Requesting %d slaves as spot instances with price $%.3f" %
- (opts.slaves, opts.spot_price))
- zones = get_zones(conn, opts)
- num_zones = len(zones)
- i = 0
- my_req_ids = []
- for zone in zones:
- num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
- slave_reqs = conn.request_spot_instances(
- price=opts.spot_price,
- image_id=opts.ami,
- launch_group="launch-group-%s" % cluster_name,
- placement=zone,
- count=num_slaves_this_zone,
- key_name=opts.key_pair,
- security_group_ids=[slave_group.id] + additional_group_ids,
- instance_type=opts.instance_type,
- block_device_map=block_map,
- subnet_id=opts.subnet_id,
- placement_group=opts.placement_group,
- user_data=user_data_content,
- instance_profile_name=opts.instance_profile_name)
- my_req_ids += [req.id for req in slave_reqs]
- i += 1
-
- print("Waiting for spot instances to be granted...")
- try:
- while True:
- time.sleep(10)
- reqs = conn.get_all_spot_instance_requests()
- id_to_req = {}
- for r in reqs:
- id_to_req[r.id] = r
- active_instance_ids = []
- for i in my_req_ids:
- if i in id_to_req and id_to_req[i].state == "active":
- active_instance_ids.append(id_to_req[i].instance_id)
- if len(active_instance_ids) == opts.slaves:
- print("All %d slaves granted" % opts.slaves)
- reservations = conn.get_all_reservations(active_instance_ids)
- slave_nodes = []
- for r in reservations:
- slave_nodes += r.instances
- break
- else:
- print("%d of %d slaves granted, waiting longer" % (
- len(active_instance_ids), opts.slaves))
- except:
- print("Canceling spot instance requests")
- conn.cancel_spot_instance_requests(my_req_ids)
- # Log a warning if any of these requests actually launched instances:
- (master_nodes, slave_nodes) = get_existing_cluster(
- conn, opts, cluster_name, die_on_error=False)
- running = len(master_nodes) + len(slave_nodes)
- if running:
- print(("WARNING: %d instances are still running" % running), file=stderr)
- sys.exit(0)
- else:
- # Launch non-spot instances
- zones = get_zones(conn, opts)
- num_zones = len(zones)
- i = 0
- slave_nodes = []
- for zone in zones:
- num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
- if num_slaves_this_zone > 0:
- slave_res = image.run(
- key_name=opts.key_pair,
- security_group_ids=[slave_group.id] + additional_group_ids,
- instance_type=opts.instance_type,
- placement=zone,
- min_count=num_slaves_this_zone,
- max_count=num_slaves_this_zone,
- block_device_map=block_map,
- subnet_id=opts.subnet_id,
- placement_group=opts.placement_group,
- user_data=user_data_content,
- instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
- instance_profile_name=opts.instance_profile_name)
- slave_nodes += slave_res.instances
- print("Launched {s} slave{plural_s} in {z}, regid = {r}".format(
- s=num_slaves_this_zone,
- plural_s=('' if num_slaves_this_zone == 1 else 's'),
- z=zone,
- r=slave_res.id))
- i += 1
-
- # Launch or resume masters
- if existing_masters:
- print("Starting master...")
- for inst in existing_masters:
- if inst.state not in ["shutting-down", "terminated"]:
- inst.start()
- master_nodes = existing_masters
- else:
- master_type = opts.master_instance_type
- if master_type == "":
- master_type = opts.instance_type
- if opts.zone == 'all':
- opts.zone = random.choice(conn.get_all_zones()).name
- master_res = image.run(
- key_name=opts.key_pair,
- security_group_ids=[master_group.id] + additional_group_ids,
- instance_type=master_type,
- placement=opts.zone,
- min_count=1,
- max_count=1,
- block_device_map=block_map,
- subnet_id=opts.subnet_id,
- placement_group=opts.placement_group,
- user_data=user_data_content,
- instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior,
- instance_profile_name=opts.instance_profile_name)
-
- master_nodes = master_res.instances
- print("Launched master in %s, regid = %s" % (zone, master_res.id))
-
- # This wait time corresponds to SPARK-4983
- print("Waiting for AWS to propagate instance metadata...")
- time.sleep(15)
-
- # Give the instances descriptive names and set additional tags
- additional_tags = {}
- if opts.additional_tags.strip():
- additional_tags = dict(
- map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',')
- )
-
- for master in master_nodes:
- master.add_tags(
- dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
- )
-
- for slave in slave_nodes:
- slave.add_tags(
- dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
- )
-
- # Return all the instances
- return (master_nodes, slave_nodes)
-
-
-def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
- """
- Get the EC2 instances in an existing cluster if available.
- Returns a tuple of lists of EC2 instance objects for the masters and slaves.
- """
- print("Searching for existing cluster {c} in region {r}...".format(
- c=cluster_name, r=opts.region))
-
- def get_instances(group_names):
- """
- Get all non-terminated instances that belong to any of the provided security groups.
-
- EC2 reservation filters and instance states are documented here:
- http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options
- """
- reservations = conn.get_all_reservations(
- filters={"instance.group-name": group_names})
- instances = itertools.chain.from_iterable(r.instances for r in reservations)
- return [i for i in instances if i.state not in ["shutting-down", "terminated"]]
-
- master_instances = get_instances([cluster_name + "-master"])
- slave_instances = get_instances([cluster_name + "-slaves"])
-
- if any((master_instances, slave_instances)):
- print("Found {m} master{plural_m}, {s} slave{plural_s}.".format(
- m=len(master_instances),
- plural_m=('' if len(master_instances) == 1 else 's'),
- s=len(slave_instances),
- plural_s=('' if len(slave_instances) == 1 else 's')))
-
- if not master_instances and die_on_error:
- print("ERROR: Could not find a master for cluster {c} in region {r}.".format(
- c=cluster_name, r=opts.region), file=sys.stderr)
- sys.exit(1)
-
- return (master_instances, slave_instances)
-
-
-# Deploy configuration files and run setup scripts on a newly launched
-# or started EC2 cluster.
-def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
- master = get_dns_name(master_nodes[0], opts.private_ips)
- if deploy_ssh_key:
- print("Generating cluster's SSH key on master...")
- key_setup = """
- [ -f ~/.ssh/id_rsa ] ||
- (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
- cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
- """
- ssh(master, opts, key_setup)
- dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
- print("Transferring cluster's SSH key to slaves...")
- for slave in slave_nodes:
- slave_address = get_dns_name(slave, opts.private_ips)
- print(slave_address)
- ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar)
-
- modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs',
- 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio']
-
- if opts.hadoop_major_version == "1":
- modules = list(filter(lambda x: x != "mapreduce", modules))
-
- if opts.ganglia:
- modules.append('ganglia')
-
- # Clear SPARK_WORKER_INSTANCES if running on YARN
- if opts.hadoop_major_version == "yarn":
- opts.worker_instances = ""
-
- # NOTE: We should clone the repository before running deploy_files to
- # prevent ec2-variables.sh from being overwritten
- print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
- r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch))
- ssh(
- host=master,
- opts=opts,
- command="rm -rf spark-ec2"
- + " && "
- + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo,
- b=opts.spark_ec2_git_branch)
- )
-
- print("Deploying files to master...")
- deploy_files(
- conn=conn,
- root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
- opts=opts,
- master_nodes=master_nodes,
- slave_nodes=slave_nodes,
- modules=modules
- )
-
- if opts.deploy_root_dir is not None:
- print("Deploying {s} to master...".format(s=opts.deploy_root_dir))
- deploy_user_files(
- root_dir=opts.deploy_root_dir,
- opts=opts,
- master_nodes=master_nodes
- )
-
- print("Running setup on master...")
- setup_spark_cluster(master, opts)
- print("Done!")
-
-
-def setup_spark_cluster(master, opts):
- ssh(master, opts, "chmod u+x spark-ec2/setup.sh")
- ssh(master, opts, "spark-ec2/setup.sh")
- print("Spark standalone cluster started at http://%s:8080" % master)
-
- if opts.ganglia:
- print("Ganglia started at http://%s:5080/ganglia" % master)
-
-
-def is_ssh_available(host, opts, print_ssh_output=True):
- """
- Check if SSH is available on a host.
- """
- s = subprocess.Popen(
- ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3',
- '%s@%s' % (opts.user, host), stringify_command('true')],
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order
- )
- cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout
-
- if s.returncode != 0 and print_ssh_output:
- # extra leading newline is for spacing in wait_for_cluster_state()
- print(textwrap.dedent("""\n
- Warning: SSH connection error. (This could be temporary.)
- Host: {h}
- SSH return code: {r}
- SSH output: {o}
- """).format(
- h=host,
- r=s.returncode,
- o=cmd_output.strip()
- ))
-
- return s.returncode == 0
-
-
-def is_cluster_ssh_available(cluster_instances, opts):
- """
- Check if SSH is available on all the instances in a cluster.
- """
- for i in cluster_instances:
- dns_name = get_dns_name(i, opts.private_ips)
- if not is_ssh_available(host=dns_name, opts=opts):
- return False
- else:
- return True
-
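# Note (illustrative aside, not part of the original script): the `else` above
# is attached to the `for` loop, not to an `if`. A for/else runs its `else`
# branch only when the loop finishes without hitting `break`; here the early
# `return False` exits the function instead, so `return True` is reached only
# once every instance passes the SSH check (or the instance list is empty).
for x in [1, 2, 3]:
    if x > 10:
        break
else:
    print("no element exceeded 10")  # prints, because the loop never hit `break`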
-
-def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state):
- """
- Wait for all the instances in the cluster to reach a designated state.
-
- cluster_instances: a list of boto.ec2.instance.Instance
- cluster_state: a string representing the desired state of all the instances in the cluster
- value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as
- 'running', 'terminated', etc.
- (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250)
- """
- sys.stdout.write(
- "Waiting for cluster to enter '{s}' state.".format(s=cluster_state)
- )
- sys.stdout.flush()
-
- start_time = datetime.now()
- num_attempts = 0
-
- while True:
- time.sleep(5 * num_attempts) # seconds
-
- for i in cluster_instances:
- i.update()
-
- max_batch = 100
- statuses = []
- for j in xrange(0, len(cluster_instances), max_batch):
- batch = [i.id for i in cluster_instances[j:j + max_batch]]
- statuses.extend(conn.get_all_instance_status(instance_ids=batch))
-
- if cluster_state == 'ssh-ready':
- if all(i.state == 'running' for i in cluster_instances) and \
- all(s.system_status.status == 'ok' for s in statuses) and \
- all(s.instance_status.status == 'ok' for s in statuses) and \
- is_cluster_ssh_available(cluster_instances, opts):
- break
- else:
- if all(i.state == cluster_state for i in cluster_instances):
- break
-
- num_attempts += 1
-
- sys.stdout.write(".")
- sys.stdout.flush()
-
- sys.stdout.write("\n")
-
- end_time = datetime.now()
- print("Cluster is now in '{s}' state. Waited {t} seconds.".format(
- s=cluster_state,
- t=(end_time - start_time).seconds
- ))
-
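# Illustrative sketch (not from the original script): the docstring above wishes
# for a proper enum in place of the 'ssh-ready' / 'running' / 'terminated'
# strings. With the stdlib enum module (the enum34 backport on Python 2), a
# minimal version could look like this; the class and member names are made up:
from enum import Enum

class ClusterState(Enum):
    SSH_READY = 'ssh-ready'
    RUNNING = 'running'
    TERMINATED = 'terminated'

# e.g. wait_for_cluster_state(conn, opts, instances, ClusterState.SSH_READY.value)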
-
-# Get number of local disks available for a given EC2 instance type.
-def get_num_disks(instance_type):
- # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html
- # Last Updated: 2015-06-19
- # For easy maintainability, please keep this manually-inputted dictionary sorted by key.
- disks_by_instance = {
- "c1.medium": 1,
- "c1.xlarge": 4,
- "c3.large": 2,
- "c3.xlarge": 2,
- "c3.2xlarge": 2,
- "c3.4xlarge": 2,
- "c3.8xlarge": 2,
- "c4.large": 0,
- "c4.xlarge": 0,
- "c4.2xlarge": 0,
- "c4.4xlarge": 0,
- "c4.8xlarge": 0,
- "cc1.4xlarge": 2,
- "cc2.8xlarge": 4,
- "cg1.4xlarge": 2,
- "cr1.8xlarge": 2,
- "d2.xlarge": 3,
- "d2.2xlarge": 6,
- "d2.4xlarge": 12,
- "d2.8xlarge": 24,
- "g2.2xlarge": 1,
- "g2.8xlarge": 2,
- "hi1.4xlarge": 2,
- "hs1.8xlarge": 24,
- "i2.xlarge": 1,
- "i2.2xlarge": 2,
- "i2.4xlarge": 4,
- "i2.8xlarge": 8,
- "m1.small": 1,
- "m1.medium": 1,
- "m1.large": 2,
- "m1.xlarge": 4,
- "m2.xlarge": 1,
- "m2.2xlarge": 1,
- "m2.4xlarge": 2,
- "m3.medium": 1,
- "m3.large": 1,
- "m3.xlarge": 2,
- "m3.2xlarge": 2,
- "m4.large": 0,
- "m4.xlarge": 0,
- "m4.2xlarge": 0,
- "m4.4xlarge": 0,
- "m4.10xlarge": 0,
- "r3.large": 1,
- "r3.xlarge": 1,
- "r3.2xlarge": 1,
- "r3.4xlarge": 1,
- "r3.8xlarge": 2,
- "t1.micro": 0,
- "t2.micro": 0,
- "t2.small": 0,
- "t2.medium": 0,
- "t2.large": 0,
- }
- if instance_type in disks_by_instance:
- return disks_by_instance[instance_type]
- else:
- print("WARNING: Don't know number of disks on instance type %s; assuming 1"
- % instance_type, file=stderr)
- return 1
-
-
-# Deploy the configuration file templates in a given local directory to
-# a cluster, filling in any template parameters with information about the
-# cluster (e.g. lists of masters and slaves). Files are only deployed to
-# the first master instance in the cluster, and we expect the setup
-# script to be run on that instance to copy them to other nodes.
-#
-# root_dir should be an absolute path to the directory with the files we want to deploy.
-def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
- active_master = get_dns_name(master_nodes[0], opts.private_ips)
-
- num_disks = get_num_disks(opts.instance_type)
- hdfs_data_dirs = "/mnt/ephemeral-hdfs/data"
- mapred_local_dirs = "/mnt/hadoop/mrlocal"
- spark_local_dirs = "/mnt/spark"
- if num_disks > 1:
- for i in range(2, num_disks + 1):
- hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i
- mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i
- spark_local_dirs += ",/mnt%d/spark" % i
-
- cluster_url = "%s:7077" % active_master
-
- if "." in opts.spark_version:
- # Pre-built Spark deploy
- spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo)
- tachyon_v = get_tachyon_version(spark_v)
- else:
- # Spark-only custom deploy
- spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version)
- tachyon_v = ""
- print("Deploying Spark via git hash; Tachyon won't be set up")
- modules = filter(lambda x: x != "tachyon", modules)
-
- master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
- slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes]
- worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else ""
- template_vars = {
- "master_list": '\n'.join(master_addresses),
- "active_master": active_master,
- "slave_list": '\n'.join(slave_addresses),
- "cluster_url": cluster_url,
- "hdfs_data_dirs": hdfs_data_dirs,
- "mapred_local_dirs": mapred_local_dirs,
- "spark_local_dirs": spark_local_dirs,
- "swap": str(opts.swap),
- "modules": '\n'.join(modules),
- "spark_version": spark_v,
- "tachyon_version": tachyon_v,
- "hadoop_major_version": opts.hadoop_major_version,
- "spark_worker_instances": worker_instances_str,
- "spark_master_opts": opts.master_opts
- }
-
- if opts.copy_aws_credentials:
- template_vars["aws_access_key_id"] = conn.aws_access_key_id
- template_vars["aws_secret_access_key"] = conn.aws_secret_access_key
- else:
- template_vars["aws_access_key_id"] = ""
- template_vars["aws_secret_access_key"] = ""
-
- # Create a temp directory in which we will place all the files to be
- # deployed after we substitute template parameters in them
- tmp_dir = tempfile.mkdtemp()
- for path, dirs, files in os.walk(root_dir):
- if path.find(".svn") == -1:
- dest_dir = os.path.join('/', path[len(root_dir):])
- local_dir = tmp_dir + dest_dir
- if not os.path.exists(local_dir):
- os.makedirs(local_dir)
- for filename in files:
- if filename[0] not in '#.~' and filename[-1] != '~':
- dest_file = os.path.join(dest_dir, filename)
- local_file = tmp_dir + dest_file
- with open(os.path.join(path, filename)) as src:
- with open(local_file, "w") as dest:
- text = src.read()
- for key in template_vars:
- text = text.replace("{{" + key + "}}", template_vars[key])
- dest.write(text)
- dest.close()
- # rsync the whole directory over to the master machine
- command = [
- 'rsync', '-rv',
- '-e', stringify_command(ssh_command(opts)),
- "%s/" % tmp_dir,
- "%s@%s:/" % (opts.user, active_master)
- ]
- subprocess.check_call(command)
- # Remove the temp directory we created above
- shutil.rmtree(tmp_dir)
-
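# Illustrative example (not from the original script): the substitution in
# deploy_files is a plain string replace of "{{key}}" markers, so a template
# fragment gets filled in like this (the variable values below are made up):
template_vars = {"active_master": "ec2-203-0-113-10.compute-1.amazonaws.com",
                 "cluster_url": "ec2-203-0-113-10.compute-1.amazonaws.com:7077"}
text = "export MASTERS={{active_master}}\nexport SPARK_MASTER_URL=spark://{{cluster_url}}\n"
for key in template_vars:
    text = text.replace("{{" + key + "}}", template_vars[key])
# text now reads:
#   export MASTERS=ec2-203-0-113-10.compute-1.amazonaws.com
#   export SPARK_MASTER_URL=spark://ec2-203-0-113-10.compute-1.amazonaws.com:7077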
-
-# Deploy a given local directory to a cluster, WITHOUT parameter substitution.
-# Note that unlike deploy_files, this works for binary files.
-# Also, it is up to the user to add (or not) the trailing slash in root_dir.
-# Files are only deployed to the first master instance in the cluster.
-#
-# root_dir should be an absolute path.
-def deploy_user_files(root_dir, opts, master_nodes):
- active_master = get_dns_name(master_nodes[0], opts.private_ips)
- command = [
- 'rsync', '-rv',
- '-e', stringify_command(ssh_command(opts)),
- "%s" % root_dir,
- "%s@%s:/" % (opts.user, active_master)
- ]
- subprocess.check_call(command)
-
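# Note (illustrative, not from the original script): with rsync the trailing
# slash on root_dir decides what lands on the master, which is why the comment
# above leaves it up to the user. For example (paths are made up):
#   rsync -rv /opt/myapp  user@master:/   -> creates /myapp on the master
#   rsync -rv /opt/myapp/ user@master:/   -> copies only the contents of /opt/myapp into /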
-
-def stringify_command(parts):
- if isinstance(parts, str):
- return parts
- else:
- return ' '.join(map(pipes.quote, parts))
-
-
-def ssh_args(opts):
- parts = ['-o', 'StrictHostKeyChecking=no']
- parts += ['-o', 'UserKnownHostsFile=/dev/null']
- if opts.identity_file is not None:
- parts += ['-i', opts.identity_file]
- return parts
-
-
-def ssh_command(opts):
- return ['ssh'] + ssh_args(opts)
-
-
-# Run a command on a host through ssh, retrying up to five times
-# and then throwing an exception if ssh continues to fail.
-def ssh(host, opts, command):
- tries = 0
- while True:
- try:
- return subprocess.check_call(
- ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host),
- stringify_command(command)])
- except subprocess.CalledProcessError as e:
- if tries > 5:
- # If this was an ssh failure, provide the user with hints.
- if e.returncode == 255:
- raise UsageError(
- "Failed to SSH to remote host {0}.\n"
- "Please check that you have provided the correct --identity-file and "
- "--key-pair parameters and try again.".format(host))
- else:
- raise e
- print("Error executing remote command, retrying after 30 seconds: {0}".format(e),
- file=stderr)
- time.sleep(30)
- tries = tries + 1
-
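# Note (illustrative, not from the original script): OpenSSH exits with status
# 255 when the connection or authentication itself fails, and otherwise with
# the exit status of the remote command. That is why a final failure with
# returncode 255 is reported as a key/identity problem above, while any other
# code is re-raised unchanged.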
-
- # Backported from Python 2.7 for compatibility with 2.6 (See SPARK-1990)
-def _check_output(*popenargs, **kwargs):
- if 'stdout' in kwargs:
- raise ValueError('stdout argument not allowed, it will be overridden.')
- process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
- output, unused_err = process.communicate()
- retcode = process.poll()
- if retcode:
- cmd = kwargs.get("args")
- if cmd is None:
- cmd = popenargs[0]
- raise subprocess.CalledProcessError(retcode, cmd, output=output)
- return output
-
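# Illustrative usage (not from the original script): _check_output mirrors
# subprocess.check_output from Python 2.7+, returning the command's stdout and
# raising CalledProcessError on a non-zero exit status, e.g.:
#   _check_output(['echo', 'hello'])  -> 'hello\n'
#   _check_output(['false'])          -> raises subprocess.CalledProcessError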
-
-def ssh_read(host, opts, command):
- return _check_output(
- ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)])
-
-
-def ssh_write(host, opts, command, arguments):
- tries = 0
- while True:
- proc = subprocess.Popen(
- ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)],
- stdin=subprocess.PIPE)
- proc.stdin.write(arguments)
- proc.stdin.close()
- status = proc.wait()
- if status == 0:
- break
- elif tries > 5:
- raise RuntimeError("ssh_write failed with error %s" % proc.returncode)
- else:
- print("Error {0} while executing remote command, retrying after 30 seconds".
- format(status), file=stderr)
- time.sleep(30)
- tries = tries + 1
-
-
-# Gets a list of zones to launch instances in
-def get_zones(conn, opts):
- if opts.zone == 'all':
- zones = [z.name for z in conn.get_all_zones()]
- else:
- zones = [opts.zone]
- return zones
-
-
- # Gets the number of items to place in the given partition when dividing
- # total items as evenly as possible (used to spread slaves across zones)
-def get_partition(total, num_partitions, current_partitions):
- num_slaves_this_zone = total // num_partitions
- if (total % num_partitions) - current_partitions > 0:
- num_slaves_this_zone += 1
- return num_slaves_this_zone
-
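# Worked example (not from the original script): spreading 7 slaves across
# 3 availability zones hands the remainder to the earlier zones:
#   get_partition(7, 3, 0) == 3
#   get_partition(7, 3, 1) == 2
#   get_partition(7, 3, 2) == 2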
-
-# Gets the IP address, taking into account the --private-ips flag
-def get_ip_address(instance, private_ips=False):
- ip = instance.ip_address if not private_ips else \
- instance.private_ip_address
- return ip
-
-
-# Gets the DNS name, taking into account the --private-ips flag
-def get_dns_name(instance, private_ips=False):
- dns = instance.public_dns_name if not private_ips else \
- instance.private_ip_address
- if not dns:
- raise UsageError("Failed to determine hostname of {0}.\n"
- "Please check that you provided --private-ips if "
- "necessary".format(instance))
- return dns
-
-
-def real_main():
- (opts, action, cluster_name) = parse_args()
-
- # Input parameter validation
- get_validate_spark_version(opts.spark_version, opts.spark_git_repo)
-
- if opts.wait is not None:
- # NOTE: DeprecationWarnings are silent in 2.7+ by default.
- # To show them, run Python with the -Wdefault switch.
- # See: https://docs.python.org/3.5/whatsnew/2.7.html
- warnings.warn(
- "This option is deprecated and has no effect. "
- "spark-ec2 automatically waits as long as necessary for clusters to start up.",
- DeprecationWarning
- )
-
- if opts.identity_file is not None:
- if not os.path.exists(opts.identity_file):
- print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file),
- file=stderr)
- sys.exit(1)
-
- file_mode = os.stat(opts.identity_file).st_mode
- if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00':
- print("ERROR: The identity file must be accessible only by you.", file=stderr)
- print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file),
- file=stderr)
- sys.exit(1)
-
- if opts.instance_type not in EC2_INSTANCE_TYPES:
- print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
- t=opts.instance_type), file=stderr)
-
- if opts.master_instance_type != "":
- if opts.master_instance_type not in EC2_INSTANCE_TYPES:
- print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
- t=opts.master_instance_type), file=stderr)
- # Since we try instance types even if we can't resolve them, we check if they resolve first
- # and, if they do, see if they resolve to the same virtualization type.
- if opts.instance_type in EC2_INSTANCE_TYPES and \
- opts.master_instance_type in EC2_INSTANCE_TYPES:
- if EC2_INSTANCE_TYPES[opts.instance_type] != \
- EC2_INSTANCE_TYPES[opts.master_instance_type]:
- print("Error: spark-ec2 currently does not support having a master and slaves "
- "with different AMI virtualization types.", file=stderr)
- print("master instance virtualization type: {t}".format(
- t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr)
- print("slave instance virtualization type: {t}".format(
- t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr)
- sys.exit(1)
-
- if opts.ebs_vol_num > 8:
- print("ebs-vol-num cannot be greater than 8", file=stderr)
- sys.exit(1)
-
- # Prevent breaking ami_prefix (/, .git and startswith checks)
- # Prevent forks with non spark-ec2 names for now.
- if opts.spark_ec2_git_repo.endswith("/") or \
- opts.spark_ec2_git_repo.endswith(".git") or \
- not opts.spark_ec2_git_repo.startswith("https://github.com") or \
- not opts.spark_ec2_git_repo.endswith("spark-ec2"):
- print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. "
- "Furthermore, we currently only support forks named spark-ec2.", file=stderr)
- sys.exit(1)
-
- if not (opts.deploy_root_dir is None or
- (os.path.isabs(opts.deploy_root_dir) and
- os.path.isdir(opts.deploy_root_dir) and
- os.path.exists(opts.deploy_root_dir))):
- print("--deploy-root-dir must be an absolute path to a directory that exists "
- "on the local file system", file=stderr)
- sys.exit(1)
-
- try:
- if opts.profile is None:
- conn = ec2.connect_to_region(opts.region)
- else:
- conn = ec2.connect_to_region(opts.region, profile_name=opts.profile)
- except Exception as e:
- print((e), file=stderr)
- sys.exit(1)
-
- # Select an AZ at random if it was not specified.
- if opts.zone == "":
- opts.zone = random.choice(conn.get_all_zones()).name
-
- if action == "launch":
- if opts.slaves <= 0:
- print("ERROR: You have to start at least 1 slave", file=sys.stderr)
- sys.exit(1)
- if opts.resume:
- (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- else:
- (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name)
- wait_for_cluster_state(
- conn=conn,
- opts=opts,
- cluster_instances=(master_nodes + slave_nodes),
- cluster_state='ssh-ready'
- )
- setup_cluster(conn, master_nodes, slave_nodes, opts, True)
-
- elif action == "destroy":
- (master_nodes, slave_nodes) = get_existing_cluster(
- conn, opts, cluster_name, die_on_error=False)
-
- if any(master_nodes + slave_nodes):
- print("The following instances will be terminated:")
- for inst in master_nodes + slave_nodes:
- print("> %s" % get_dns_name(inst, opts.private_ips))
- print("ALL DATA ON ALL NODES WILL BE LOST!!")
-
- msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name)
- response = raw_input(msg)
- if response == "y":
- print("Terminating master...")
- for inst in master_nodes:
- inst.terminate()
- print("Terminating slaves...")
- for inst in slave_nodes:
- inst.terminate()
-
- # Delete security groups as well
- if opts.delete_groups:
- group_names = [cluster_name + "-master", cluster_name + "-slaves"]
- wait_for_cluster_state(
- conn=conn,
- opts=opts,
- cluster_instances=(master_nodes + slave_nodes),
- cluster_state='terminated'
- )
- print("Deleting security groups (this will take some time)...")
- attempt = 1
- while attempt <= 3:
- print("Attempt %d" % attempt)
- groups = [g for g in conn.get_all_security_groups() if g.name in group_names]
- success = True
- # Delete individual rules in all groups before deleting groups to
- # remove dependencies between them
- for group in groups:
- print("Deleting rules in security group " + group.name)
- for rule in group.rules:
- for grant in rule.grants:
- success &= group.revoke(ip_protocol=rule.ip_protocol,
- from_port=rule.from_port,
- to_port=rule.to_port,
- src_group=grant)
-
- # Sleep for AWS eventual-consistency to catch up, and for instances
- # to terminate
- time.sleep(30) # Yes, it does have to be this long :-(
- for group in groups:
- try:
- # Use group_id here so that deletion also works for VPC security groups
- conn.delete_security_group(group_id=group.id)
- print("Deleted security group %s" % group.name)
- except boto.exception.EC2ResponseError:
- success = False
- print("Failed to delete security group %s" % group.name)
-
- # Unfortunately, group.revoke() returns True even if a rule was not
- # deleted, so this needs to be rerun if something fails
- if success:
- break
-
- attempt += 1
-
- if not success:
- print("Failed to delete all security groups after 3 tries.")
- print("Try re-running in a few minutes.")
-
- elif action == "login":
- (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- if not master_nodes[0].public_dns_name and not opts.private_ips:
- print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
- else:
- master = get_dns_name(master_nodes[0], opts.private_ips)
- print("Logging into master " + master + "...")
- proxy_opt = []
- if opts.proxy_port is not None:
- proxy_opt = ['-D', opts.proxy_port]
- subprocess.check_call(
- ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)])
-
- elif action == "reboot-slaves":
- response = raw_input(
- "Are you sure you want to reboot the cluster " +
- cluster_name + " slaves?\n" +
- "Reboot cluster slaves " + cluster_name + " (y/N): ")
- if response == "y":
- (master_nodes, slave_nodes) = get_existing_cluster(
- conn, opts, cluster_name, die_on_error=False)
- print("Rebooting slaves...")
- for inst in slave_nodes:
- if inst.state not in ["shutting-down", "terminated"]:
- print("Rebooting " + inst.id)
- inst.reboot()
-
- elif action == "get-master":
- (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- if not master_nodes[0].public_dns_name and not opts.private_ips:
- print("Master has no public DNS name. Maybe you meant to specify --private-ips?")
- else:
- print(get_dns_name(master_nodes[0], opts.private_ips))
-
- elif action == "stop":
- response = raw_input(
- "Are you sure you want to stop the cluster " +
- cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " +
- "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" +
- "AMAZON EBS IF IT IS EBS-BACKED!!\n" +
- "All data on spot-instance slaves will be lost.\n" +
- "Stop cluster " + cluster_name + " (y/N): ")
- if response == "y":
- (master_nodes, slave_nodes) = get_existing_cluster(
- conn, opts, cluster_name, die_on_error=False)
- print("Stopping master...")
- for inst in master_nodes:
- if inst.state not in ["shutting-down", "terminated"]:
- inst.stop()
- print("Stopping slaves...")
- for inst in slave_nodes:
- if inst.state not in ["shutting-down", "terminated"]:
- if inst.spot_instance_request_id:
- inst.terminate()
- else:
- inst.stop()
-
- elif action == "start":
- (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
- print("Starting slaves...")
- for inst in slave_nodes:
- if inst.state not in ["shutting-down", "terminated"]:
- inst.start()
- print("Starting master...")
- for inst in master_nodes:
- if inst.state not in ["shutting-down", "terminated"]:
- inst.start()
- wait_for_cluster_state(
- conn=conn,
- opts=opts,
- cluster_instances=(master_nodes + slave_nodes),
- cluster_state='ssh-ready'
- )
-
- # Determine types of running instances
- existing_master_type = master_nodes[0].instance_type
- existing_slave_type = slave_nodes[0].instance_type
- # Setting opts.master_instance_type to the empty string indicates we
- # have the same instance type for the master and the slaves
- if existing_master_type == existing_slave_type:
- existing_master_type = ""
- opts.master_instance_type = existing_master_type
- opts.instance_type = existing_slave_type
-
- setup_cluster(conn, master_nodes, slave_nodes, opts, False)
-
- else:
- print("Invalid action: %s" % action, file=stderr)
- sys.exit(1)
-
-
-def main():
- try:
- real_main()
- except UsageError as e:
- print("\nError:\n", e, file=stderr)
- sys.exit(1)
-
-
-if __name__ == "__main__":
- logging.basicConfig()
- main()
diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py
index f6b0ecb02c100..b6c2916254056 100755
--- a/examples/src/main/python/sort.py
+++ b/examples/src/main/python/sort.py
@@ -30,7 +30,7 @@
lines = sc.textFile(sys.argv[1], 1)
sortedCount = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (int(x), 1)) \
- .sortByKey(lambda x: x)
+ .sortByKey()
# This is just a demo on how to bring all the sorted data back to a single node.
# In reality, we wouldn't want to collect all the data to the driver node.
output = sortedCount.collect()
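
This fix works because PySpark's RDD.sortByKey already sorts by the tuple key; its first positional parameter is the ascending flag, so the old lambda was silently consumed as a truthy ascending value rather than acting as a key function. A minimal standalone sketch of the corrected pattern (the local master and sample data are illustrative assumptions):

from pyspark import SparkContext

sc = SparkContext("local", "SortByKeyExample")
sorted_counts = sc.parallelize(["3", "1", "2"]) \
    .map(lambda x: (int(x), 1)) \
    .sortByKey()                   # ascending sort on the integer key
print(sorted_counts.collect())     # [(1, 1), (2, 1), (3, 1)]
sc.stop()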
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala
index 3834ea807acbf..c4336639d7c0b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala
@@ -25,7 +25,7 @@ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegression
object IsotonicRegressionExample {
- def main(args: Array[String]) : Unit = {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("IsotonicRegressionExample")
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala
index 8bae1b9d1832d..0187ad603a654 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
object NaiveBayesExample {
- def main(args: Array[String]) : Unit = {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("NaiveBayesExample")
val sc = new SparkContext(conf)
// $example on$
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
index ace16ff1ea225..add634c957b40 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SQLContext
object RegressionMetricsExample {
- def main(args: Array[String]) : Unit = {
+ def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("RegressionMetricsExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index c4e18d92eefa9..d7885d7cc1ae1 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -385,7 +385,7 @@ object KafkaCluster {
val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp =>
val hpa = hp.split(":")
if (hpa.size == 1) {
- throw new SparkException(s"Broker not the in correct format of : [$brokers]")
+ throw new SparkException(s"Broker not in the correct format of : [$brokers]")
}
(hpa(0), hpa(1).toInt)
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 603be22818206..4eb155645867b 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -156,7 +156,7 @@ class KafkaRDD[
var requestOffset = part.fromOffset
var iter: Iterator[MessageAndOffset] = null
- // The idea is to use the provided preferred host, except on task retry atttempts,
+ // The idea is to use the provided preferred host, except on task retry attempts,
// to minimize number of kafka metadata requests
private def connectLeader: SimpleConsumer = {
if (context.attemptNumber > 0) {
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index b3ba72a0087ad..d3a2bf5825b08 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -51,7 +51,7 @@
org.eclipse.paho
org.eclipse.paho.client.mqttv3
- 1.0.1
+ 1.0.2
org.scalacheck
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index de749626ec09c..6a73bc0e30d05 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.util.Random
-import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials}
+import com.amazonaws.auth.{BasicAWSCredentials, DefaultAWSCredentialsProviderChain}
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 3321c7527edb4..5223c81a8e0e0 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -24,10 +24,10 @@ import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.streaming.{Duration, StreamingContext, Time}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
-import org.apache.spark.streaming.{Duration, StreamingContext, Time}
private[kinesis] class KinesisInputDStream[T: ClassTag](
_ssc: StreamingContext,
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index abb9b6cd32f1c..48ee2a959786b 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory}
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
import com.amazonaws.services.kinesis.model.Record
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index 2de6195716e5c..15ac588b82587 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -24,9 +24,9 @@ import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Duration, StreamingContext}
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.{Duration, StreamingContext}
object KinesisUtils {
/**
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index d85b4cda8ce98..e6f504c4e54dd 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.streaming.kinesis
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
extends KinesisFunSuite with BeforeAndAfterAll {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index 645e64a0bc3a0..e1499a8220991 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.kinesis
-import java.util.concurrent.{TimeoutException, ExecutorService}
+import java.util.concurrent.{ExecutorService, TimeoutException}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
@@ -28,7 +28,7 @@ import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
-import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach}
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.Eventually._
import org.scalatest.mock.MockitoSugar
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index e5c70db554a27..fd15b6ccdc889 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -27,8 +27,8 @@ import com.amazonaws.services.kinesis.model.Record
import org.mockito.Matchers._
import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito._
-import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, Matchers}
+import org.scalatest.mock.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
import org.apache.spark.util.Utils
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 78263f9dca65c..ee6a5f0390d04 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -25,10 +25,11 @@ import scala.util.Random
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
@@ -38,7 +39,6 @@ import org.apache.spark.streaming.kinesis.KinesisTestUtils._
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkConf, SparkContext}
abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite
with Eventually with BeforeAndAfter with BeforeAndAfterAll {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index fc36e12dd2aed..d048fb5d561f3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
import scala.util.Random
import org.apache.spark.SparkException
-import org.apache.spark.SparkContext._
import org.apache.spark.graphx.lib._
import org.apache.spark.rdd.RDD
@@ -379,7 +378,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]]
*/
def personalizedPageRank(src: VertexId, tol: Double,
- resetProb: Double = 0.15) : Graph[Double, Double] = {
+ resetProb: Double = 0.15): Graph[Double, Double] = {
PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src))
}
@@ -392,7 +391,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]]
*/
def staticPersonalizedPageRank(src: VertexId, numIter: Int,
- resetProb: Double = 0.15) : Graph[Double, Double] = {
+ resetProb: Double = 0.15): Graph[Double, Double] = {
PageRank.runWithOptions(graph, numIter, resetProb, Some(src))
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
index f79f9c7ec448f..b4bec7cba5207 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -41,8 +41,8 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag](
* shipping level.
*/
def withEdges[VD2: ClassTag, ED2: ClassTag](
- edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = {
- new ReplicatedVertexView(edges_, hasSrcId, hasDstId)
+ _edges: EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = {
+ new ReplicatedVertexView(_edges, hasSrcId, hasDstId)
}
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
index 3f203c4eca485..96d807f9f9ceb 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
@@ -102,8 +102,8 @@ class ShippableVertexPartition[VD: ClassTag](
extends VertexPartitionBase[VD] {
/** Return a new ShippableVertexPartition with the specified routing table. */
- def withRoutingTable(routingTable_ : RoutingTablePartition): ShippableVertexPartition[VD] = {
- new ShippableVertexPartition(index, values, mask, routingTable_)
+ def withRoutingTable(_routingTable: RoutingTablePartition): ShippableVertexPartition[VD] = {
+ new ShippableVertexPartition(index, values, mask, _routingTable)
}
/**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
index f508b483a2f1b..7c680dcb99cd2 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.collection.BitSet
* example, [[VertexPartition.VertexPartitionOpsConstructor]]).
*/
private[graphx] abstract class VertexPartitionBaseOps
- [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor]
+ [VD: ClassTag, Self[X] <: VertexPartitionBase[X]: VertexPartitionBaseOpsConstructor]
(self: Self[VD])
extends Serializable with Logging {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 35b26c998e1d9..46faad2e68c50 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -138,7 +138,7 @@ object PageRank extends Logging {
// edge partitions.
prevRankGraph = rankGraph
val rPrb = if (personalized) {
- (src: VertexId , id: VertexId) => resetProb * delta(src, id)
+ (src: VertexId, id: VertexId) => resetProb * delta(src, id)
} else {
(src: VertexId, id: VertexId) => resetProb
}
diff --git a/make-distribution.sh b/make-distribution.sh
index a38fd8df17206..327659298e4d8 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -212,7 +212,6 @@ cp "$SPARK_HOME/README.md" "$DISTDIR"
cp -r "$SPARK_HOME/bin" "$DISTDIR"
cp -r "$SPARK_HOME/python" "$DISTDIR"
cp -r "$SPARK_HOME/sbin" "$DISTDIR"
-cp -r "$SPARK_HOME/ec2" "$DISTDIR"
# Copy SparkR if it exists
if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
mkdir -p "$DISTDIR"/R/lib
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 08a51109d6c62..c41a611f1cc60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -113,13 +113,13 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}
- val transformedDataset = model.transform(df).select(columns : _*)
+ val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
.withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
val newColumns = origCols ++ List(col(tmpColName))
// switch out the intermediate column with the accumulator column
- updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName)
+ updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
}
if (handlePersistence) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f9952434d2982..6cc9d025445c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -238,7 +238,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
override def transform(dataset: DataFrame): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
- dataset.select(columnsToKeep.map(dataset.col) : _*)
+ dataset.select(columnsToKeep.map(dataset.col): _*)
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 0b215659b3672..716bc63e00995 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -102,7 +102,7 @@ class VectorAssembler(override val uid: String)
}
}
- dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata))
+ dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 6e87302c7779b..d3376a7dff938 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -474,7 +474,7 @@ private[ml] object RandomForest extends Logging {
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -825,7 +825,7 @@ private[ml] object RandomForest extends Logging {
protected[tree] def findSplits(
input: RDD[LabeledPoint],
metadata: DecisionTreeMetadata,
- seed : Long): Array[Array[Split]] = {
+ seed: Long): Array[Array[Split]] = {
logDebug("isMulticlass = " + metadata.isMulticlass)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 7443097492d82..7a651a37ac77e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
-import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Parameters for Decision Tree-based algorithms.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 5c9bc62cb09bb..16bc45bcb627f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -177,7 +177,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
}
@Since("1.4.0")
- override def load(sc: SparkContext, path: String) : GaussianMixtureModel = {
+ override def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
implicit val formats = DefaultFormats
val k = (metadata \ "k").extract[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 5273ed4d76650..1250bc1a07cb4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,8 +17,8 @@
package org.apache.spark.mllib.fpm
-import java.lang.{Iterable => JavaIterable}
import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -29,16 +29,15 @@ import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}
-import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkContext, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.fpm.FPGrowth._
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -134,7 +133,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
loadImpl(freqItemsets, sample)
}
- def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+ def loadImpl[Item: ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
val items = x.getAs[Seq[Item]](0).toArray
val freq = x.getLong(1)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index d7a74db0b1fd8..b08da4fb55034 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -279,7 +279,7 @@ class DenseMatrix @Since("1.3.0") (
}
override def hashCode: Int = {
- com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray)
+ com.google.common.base.Objects.hashCode(numRows: Integer, numCols: Integer, toArray)
}
private[mllib] def toBreeze: BM[Double] = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
index 7abb1bf7ce967..a8c32f72bfdeb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
- model : GeneralizedLinearModel,
- description : String,
- normalizationMethod : RegressionNormalizationMethodType,
+ model: GeneralizedLinearModel,
+ description: String,
+ normalizationMethod: RegressionNormalizationMethodType,
threshold: Double)
extends PMMLModelExport {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
index b5b824bb9c9b6..255c6140e5410 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
@@ -26,14 +26,14 @@ import org.apache.spark.mllib.clustering.KMeansModel
/**
* PMML Model Export for KMeansModel class
*/
-private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
+private[mllib] class KMeansPMMLModelExport(model: KMeansModel) extends PMMLModelExport{
populateKMeansPMML(model)
/**
* Export the input KMeansModel model to PMML format.
*/
- private def populateKMeansPMML(model : KMeansModel): Unit = {
+ private def populateKMeansPMML(model: KMeansModel): Unit = {
pmml.getHeader.setDescription("k-means clustering")
if (model.clusterCenters.length > 0) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index af1f7e74c004d..07ba0d8ccb2a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -25,10 +25,10 @@ import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
-import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl._
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
@@ -600,7 +600,7 @@ object DecisionTree extends Serializable with Logging {
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 729a211574822..1b71256c585bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -22,8 +22,8 @@ import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index a684cdd18c2fc..570a76f960796 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -26,9 +26,9 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 66f0908c1250f..b373c2de3ea96 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -83,7 +83,7 @@ class Node @Since("1.2.0") (
* @return predicted value
*/
@Since("1.1.0")
- def predict(features: Vector) : Double = {
+ def predict(features: Vector): Double = {
if (isLeaf) {
predict.predict
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index 094528e2ece06..240781bcd335b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -175,7 +175,7 @@ object LinearDataGenerator {
nfeatures: Int,
eps: Double,
nparts: Int = 2,
- intercept: Double = 0.0) : RDD[LabeledPoint] = {
+ intercept: Double = 0.0): RDD[LabeledPoint] = {
val random = new Random(42)
// Random values distributed uniformly in [-0.5, 0.5]
val w = Array.fill(nfeatures)(random.nextDouble() - 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index ee3c85d09a463..1a47344b68937 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -45,7 +45,7 @@ object SVMSuite {
nPoints: Int,
seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
- val weightsMat = new DoubleMatrix(1, weights.length, weights : _*)
+ val weightsMat = new DoubleMatrix(1, weights.length, weights: _*)
val x = Array.fill[Array[Double]](nPoints)(
Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
val y = x.map { xi =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
index 1142102bb040e..50441816ece3e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.util.random.XORShiftRandom
class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
- override def maxWaitTimeMillis : Int = 30000
+ override def maxWaitTimeMillis: Int = 30000
test("accuracy for null hypothesis using welch t-test") {
// set parameters
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 92ca0046d4f53..eda2b7307088f 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -55,6 +55,7 @@
com.google.guava
guava
+ compile
diff --git a/pom.xml b/pom.xml
index 9c975a45f8d23..fc5cf970e0601 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,9 +152,9 @@
1.7.7
hadoop2
0.7.1
- 1.4.0
+ 1.6.1
- 0.10.1
+ 0.10.2
4.3.2
@@ -167,7 +167,7 @@
${scala.version}
org.scala-lang
1.9.13
- 2.4.4
+ 2.5.3
1.1.2
1.1.2
1.2.0-incubating
@@ -226,93 +226,6 @@
false
-
- apache-repo
- Apache Repository
- https://repository.apache.org/content/repositories/releases
-
- true
-
-
- false
-
-
-
- jboss-repo
- JBoss Repository
- https://repository.jboss.org/nexus/content/repositories/releases
-
- true
-
-
- false
-
-
-
- mqtt-repo
- MQTT Repository
- https://repo.eclipse.org/content/repositories/paho-releases
-
- true
-
-
- false
-
-
-
- cloudera-repo
- Cloudera Repository
- https://repository.cloudera.com/artifactory/cloudera-repos
-
- true
-
-
- false
-
-
-
- spark-hive-staging
- Staging Repo for Hive 1.2.1 (Spark Version)
- https://oss.sonatype.org/content/repositories/orgspark-project-1113
-
- true
-
-
-
- mapr-repo
- MapR Repository
- http://repository.mapr.com/maven/
-
- true
-
-
- false
-
-
-
-
- spring-releases
- Spring Release Repository
- https://repo.spring.io/libs-release
-
- false
-
-
- false
-
-
-
-
- twttr-repo
- Twttr Repository
- http://maven.twttr.com
-
- true
-
-
- false
-
-
@@ -1133,6 +1046,12 @@
zookeeper
${zookeeper.version}
${hadoop.deps.scope}
+
+
+ org.jboss.netty
+ netty
+
+
org.codehaus.jackson
@@ -1858,7 +1777,7 @@
org.apache.maven.plugins
maven-enforcer-plugin
- 1.4
+ 1.4.1
enforce-versions
@@ -1873,6 +1792,19 @@
${java.version}
+
+
+
+ org.jboss.netty
+
+ true
+
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 9ba9f8286f10c..41856443af49b 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -91,11 +91,16 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
+    // The resolver setting for the MQTT Repository is needed for mqttv3 (1.0.1)
+    // because spark-streaming-mqtt (1.6.0) depends on it.
+    // Remove this setting when updating previousSparkVersion.
val previousSparkVersion = "1.6.0"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
- binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value))
+ binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value),
+ sbt.Keys.resolvers +=
+ "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases")
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 5d4f19ab14a29..4c34c888cfd5e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -141,7 +141,12 @@ object SparkBuild extends PomBuild {
publishMavenStyle := true,
unidocGenjavadocVersion := "0.9-spark0",
- resolvers += Resolver.mavenLocal,
+ // Override SBT's default resolvers:
+ resolvers := Seq(
+ DefaultMavenRepository,
+ Resolver.mavenLocal
+ ),
+ externalResolvers := resolvers.value,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map {
(arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level)
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 15ba3a36d51ca..822a7c4a82d5e 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,9 +1,3 @@
-resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)
-
-resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
-
-resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
-
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0")
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 9714c46fe99a0..2439a1f715aba 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -187,6 +187,16 @@ This file is divided into 3 sections:
scala.collection.JavaConverters._ and use .asScala / .asJava methods
+
+
+ java,scala,3rdParty,spark
+ javax?\..*
+ scala\..*
+ (?!org\.apache\.spark\.).*
+ org\.apache\.spark\..*
+
+
+
@@ -207,17 +217,6 @@ This file is divided into 3 sections:
-
-
-
- java,scala,3rdParty,spark
- javax?\..*
- scala\..*
- (?!org\.apache\.spark\.).*
- org\.apache\.spark\..*
-
-
-
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
index cad770122d150..aabb5d49582c8 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
@@ -223,7 +223,12 @@ precedenceUnaryPrefixExpression
;
precedenceUnarySuffixExpression
- : precedenceUnaryPrefixExpression (a=KW_IS nullCondition)?
+ :
+ (
+ (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN
+ |
+ precedenceUnaryPrefixExpression (a=KW_IS nullCondition)?
+ )
-> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression)
-> precedenceUnaryPrefixExpression
;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index b04bb677774c5..2c13d3056f468 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -1,9 +1,9 @@
/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
+ (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -582,7 +582,7 @@ import java.util.HashMap;
return header;
}
-
+
@Override
public String getErrorMessage(RecognitionException e, String[] tokenNames) {
String msg = null;
@@ -619,7 +619,7 @@ import java.util.HashMap;
}
return msg;
}
-
+
public void pushMsg(String msg, RecognizerSharedState state) {
    // ANTLR generated code does not wrap the @init code with this backtracking check,
    // even if the matching @after has it. If we have parser rules with that are doing
@@ -639,7 +639,7 @@ import java.util.HashMap;
// counter to generate unique union aliases
private int aliasCounter;
private String generateUnionAlias() {
- return "_u" + (++aliasCounter);
+ return "u_" + (++aliasCounter);
}
private char [] excludedCharForColumnName = {'.', ':'};
private boolean containExcludedCharForCreateTableColumnName(String input) {
@@ -1235,7 +1235,7 @@ alterTblPartitionStatementSuffixSkewedLocation
: KW_SET KW_SKEWED KW_LOCATION skewedLocations
-> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations)
;
-
+
skewedLocations
@init { pushMsg("skewed locations", state); }
@after { popMsg(state); }
@@ -1264,7 +1264,7 @@ alterStatementSuffixLocation
-> ^(TOK_ALTERTABLE_LOCATION $newLoc)
;
-
+
alterStatementSuffixSkewedby
@init {pushMsg("alter skewed by statement", state);}
@after{popMsg(state);}
@@ -1336,10 +1336,10 @@ tabTypeExpr
(identifier (DOT^
(
(KW_ELEM_TYPE) => KW_ELEM_TYPE
- |
+ |
(KW_KEY_TYPE) => KW_KEY_TYPE
- |
- (KW_VALUE_TYPE) => KW_VALUE_TYPE
+ |
+ (KW_VALUE_TYPE) => KW_VALUE_TYPE
| identifier
))*
)?
@@ -1376,7 +1376,7 @@ descStatement
analyzeStatement
@init { pushMsg("analyze statement", state); }
@after { popMsg(state); }
- : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
+ : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
| (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))?
-> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?)
;
@@ -1389,7 +1389,7 @@ showStatement
| KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)?
-> ^(TOK_SHOWCOLUMNS tableName $db_name?)
| KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?)
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
+ | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
| KW_SHOW KW_CREATE (
(KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name)
|
@@ -1398,7 +1398,7 @@ showStatement
| KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec?
-> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?)
| KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?)
- | KW_SHOW KW_LOCKS
+ | KW_SHOW KW_LOCKS
(
(KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?)
|
@@ -1511,7 +1511,7 @@ showCurrentRole
setRole
@init {pushMsg("set role", state);}
@after {popMsg(state);}
- : KW_SET KW_ROLE
+ : KW_SET KW_ROLE
(
(KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text])
|
@@ -1966,7 +1966,7 @@ columnNameOrderList
skewedValueElement
@init { pushMsg("skewed value element", state); }
@after { popMsg(state); }
- :
+ :
skewedColumnValues
| skewedColumnValuePairList
;
@@ -1980,8 +1980,8 @@ skewedColumnValuePairList
skewedColumnValuePair
@init { pushMsg("column value pair", state); }
@after { popMsg(state); }
- :
- LPAREN colValues=skewedColumnValues RPAREN
+ :
+ LPAREN colValues=skewedColumnValues RPAREN
-> ^(TOK_TABCOLVALUES $colValues)
;
@@ -2001,11 +2001,11 @@ skewedColumnValue
skewedValueLocationElement
@init { pushMsg("skewed value location element", state); }
@after { popMsg(state); }
- :
+ :
skewedColumnValue
| skewedColumnValuePair
;
-
+
columnNameOrder
@init { pushMsg("column name order", state); }
@after { popMsg(state); }
@@ -2118,7 +2118,7 @@ unionType
@after { popMsg(state); }
: KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList)
;
-
+
setOperator
@init { pushMsg("set operator", state); }
@after { popMsg(state); }
@@ -2172,7 +2172,7 @@ fromStatement[boolean topLevel]
{adaptor.create(Identifier, generateUnionAlias())}
)
)
- ^(TOK_INSERT
+ ^(TOK_INSERT
^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
)
@@ -2414,8 +2414,8 @@ setColumnsClause
KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* )
;
-/*
- UPDATE
+/*
+ UPDATE
SET col1 = val1, col2 = val2... WHERE ...
*/
updateStatement
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 1eda4a9a97644..2e3cc0bfde7c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -22,10 +22,10 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e362b55d80cd1..8a33af8207350 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,8 +86,7 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
- PullOutNondeterministic,
- ComputeCurrentTime),
+ PullOutNondeterministic),
Batch("UDF", Once,
HandleNullInputsForUDF),
Batch("Cleanup", fixedPoint,
@@ -1229,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
}
-/**
- * Computes the current date and time to make sure we return the same result in a single query.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- val dateExpr = CurrentDate()
- val timeExpr = CurrentTimestamp()
- val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
- val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
- plan transformAllExpressions {
- case CurrentDate() => currentDate
- case CurrentTimestamp() => currentTime
- }
- }
-}
-
/**
* Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index e8b2fcf819bf6..a8f89ce6de457 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -110,7 +110,9 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
- alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+ alias
+ .map(a => Subquery(a, tableWithQualifiers))
+ .getOrElse(tableWithQualifiers)
}
override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d82d3edae4e38..6f199cfc5d8cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -931,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
$evPrim = $result.copy();
"""
}
+
+ override def sql: String = dataType match {
+    // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL,
+    // this type of casting can only be introduced by the analyzer, and can be omitted when
+    // converting back to a SQL query string.
+ case _: ArrayType | _: MapType | _: StructType => child.sql
+ case _ => s"CAST(${child.sql} AS ${dataType.sql})"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6a9c12127d367..d6219514b752b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] {
protected def toCommentSafeString: String = this.toString
.replace("*/", "\\*\\/")
.replace("\\u", "\\\\u")
+
+ /**
+   * Returns the SQL representation of this expression. For expressions that don't have a SQL
+   * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
+ */
+ @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation")
+ def sql: String = throw new UnsupportedOperationException(
+ s"Cannot map expression $this to its SQL representation"
+ )
}
@@ -356,6 +366,8 @@ abstract class UnaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = s"($prettyName(${child.sql}))"
}
@@ -456,6 +468,8 @@ abstract class BinaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
@@ -492,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
TypeCheckResult.TypeCheckSuccess
}
}
+
+ override def sql: String = s"(${left.sql} $symbol ${right.sql})"
}
@@ -593,4 +609,9 @@ abstract class TernaryExpression extends Expression {
"""
}
}
+
+ override def sql: String = {
+ val childrenSQL = children.map(_.sql).mkString(", ")
+ s"$prettyName($childrenSQL)"
+ }
}
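
The defaults added above to UnaryExpression, BinaryExpression, BinaryOperator and TernaryExpression all follow one recursive pattern: a node produces its SQL text by composing the sql strings of its children. A minimal, self-contained sketch of that pattern, using hypothetical toy classes rather than Spark's expression hierarchy:

// Toy expression tree illustrating the recursive sql-composition pattern (not Spark code).
sealed trait Expr { def sql: String }

case class Col(name: String) extends Expr {
  override def sql: String = s"`$name`"                             // attribute -> backquoted name
}

case class Lit(value: Int) extends Expr {
  override def sql: String = value.toString                         // literal -> its value
}

case class BinOp(symbol: String, left: Expr, right: Expr) extends Expr {
  override def sql: String = s"(${left.sql} $symbol ${right.sql})"  // BinaryOperator-style default
}

object SqlPatternDemo extends App {
  println(BinOp("*", BinOp("+", Col("a"), Lit(1)), Col("b")).sql)   // prints ((`a` + 1) * `b`)
}
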
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index f33833c3918df..827dce8af100e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -49,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
"org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
}
+ override def sql: String = prettyName
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index d0b78e15d99d1..94f8801dec369 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -78,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
$countTerm++;
"""
}
+
+ override def prettyName: String = "monotonically_increasing_id"
+
+ override def sql: String = s"$prettyName()"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3add722da7816..1cb1b9da3049b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -24,9 +24,17 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
-abstract sealed class SortDirection
-case object Ascending extends SortDirection
-case object Descending extends SortDirection
+abstract sealed class SortDirection {
+ def sql: String
+}
+
+case object Ascending extends SortDirection {
+ override def sql: String = "ASC"
+}
+
+case object Descending extends SortDirection {
+ override def sql: String = "DESC"
+}
/**
* An expression that can be used to sort a tuple. This class extends expression primarily so that
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index b47f32d1768b9..ddd99c51ab0c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
/** The mode of an [[AggregateFunction]]. */
@@ -93,11 +94,13 @@ private[sql] case class AggregateExpression(
override def prettyString: String = aggregateFunction.prettyString
- override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
+ override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)"
+
+ override def sql: String = aggregateFunction.sql(isDistinct)
}
/**
- * AggregateFunction2 is the superclass of two aggregation function interfaces:
+ * AggregateFunction is the superclass of two aggregation function interfaces:
*
* - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of
* initialize(), update(), and merge() functions that operate on Row-based aggregation buffers.
@@ -163,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
}
+
+ def sql(isDistinct: Boolean): String = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+ s"$prettyName($distinct${children.map(_.sql).mkString(", ")})"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 61a17fd7db0fe..7bd851c059d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
numeric.negate(input)
}
}
+
+ override def sql: String = s"(-${child.sql})"
}
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
defineCodeGen(ctx, ev, c => c)
protected override def nullSafeEval(input: Any): Any = input
+
+ override def sql: String = s"(+${child.sql})"
}
/**
@@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
+
+ override def sql: String = s"$prettyName(${child.sql})"
}
abstract class BinaryArithmetic extends BinaryOperator {
@@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
val r = a % n
if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
}
+
+ override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 9c73239f67ff2..5bd97cc7467ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -130,6 +130,8 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
}
})
}
+
+ override def sql: String = child.sql + s".`${childSchema(ordinal).name}`"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index f79c8676fb58c..19da849d2bec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils}
import org.apache.spark.sql.types._
@@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
+
+ override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
}
trait CaseWhenLike extends Expression {
@@ -110,7 +112,7 @@ trait CaseWhenLike extends Expression {
override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
- thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
+ thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
}
@@ -206,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
+
+ override def sql: String = {
+ val branchesSQL = branches.map(_.sql)
+ val (cases, maybeElse) = if (branches.length % 2 == 0) {
+ (branchesSQL, None)
+ } else {
+ (branchesSQL.init, Some(branchesSQL.last))
+ }
+
+ val head = s"CASE "
+ val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
+ val body = cases.grouped(2).map {
+ case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
+ }.mkString(" ")
+
+ head + body + tail
+ }
}
// scalastyle:off
@@ -310,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
}
+
+ override def sql: String = {
+ val keySQL = key.sql
+ val branchesSQL = branches.map(_.sql)
+ val (cases, maybeElse) = if (branches.length % 2 == 0) {
+ (branchesSQL, None)
+ } else {
+ (branchesSQL.init, Some(branchesSQL.last))
+ }
+
+ val head = s"CASE $keySQL "
+ val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
+ val body = cases.grouped(2).map {
+ case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
+ }.mkString(" ")
+
+ head + body + tail
+ }
}
/**
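
The sql methods added to CaseWhen and CaseKeyWhen above rely on the layout of the branch list: WHEN/THEN expressions come in pairs, and an odd-length list carries the ELSE expression at the end. A standalone sketch of that pairing logic, as a hypothetical helper operating on already-rendered SQL strings (not Spark code):

object CaseWhenSqlDemo extends App {
  // Pairs up WHEN/THEN strings; an odd-length input contributes its last element as the ELSE.
  def caseWhenSql(branchesSQL: Seq[String]): String = {
    val (cases, maybeElse) =
      if (branchesSQL.length % 2 == 0) (branchesSQL, None)
      else (branchesSQL.init, Some(branchesSQL.last))
    val body = cases.grouped(2).map { case Seq(w, t) => s"WHEN $w THEN $t" }.mkString(" ")
    "CASE " + body + maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
  }

  // prints: CASE WHEN (`a` > 1) THEN "big" ELSE "small" END
  println(caseWhenSql(Seq("(`a` > 1)", "\"big\"", "\"small\"")))
}
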
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 3d65946a1bc65..17f1df06f2fad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback {
override def eval(input: InternalRow): Any = {
DateTimeUtils.millisToDays(System.currentTimeMillis())
}
+
+ override def prettyName: String = "current_date"
}
/**
@@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
override def eval(input: InternalRow): Any = {
System.currentTimeMillis() * 1000L
}
+
+ override def prettyName: String = "current_timestamp"
}
/**
@@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression)
s"""${ev.value} = $sd + $d;"""
})
}
+
+ override def prettyName: String = "date_add"
}
/**
@@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression)
s"""${ev.value} = $sd - $d;"""
})
}
+
+ override def prettyName: String = "date_sub"
}
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
@@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
def this(time: Expression) = {
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
}
+
+ override def prettyName: String = "to_unix_timestamp"
}
/**
@@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi
def this() = {
this(CurrentTimestamp())
}
+
+ override def prettyName: String = "unix_timestamp"
}
abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
@@ -437,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
"""
}
}
+
+ override def prettyName: String = "unix_time"
}
/**
@@ -451,6 +465,8 @@ case class FromUnixTime(sec: Expression, format: Expression)
override def left: Expression = sec
override def right: Expression = format
+ override def prettyName: String = "from_unixtime"
+
def this(unix: Expression) = {
this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
}
@@ -733,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
s"""$dtu.dateAddMonths($sd, $m)"""
})
}
+
+ override def prettyName: String = "add_months"
}
/**
@@ -758,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression)
s"""$dtu.monthsBetween($l, $r)"""
})
}
+
+ override def prettyName: String = "months_between"
}
/**
@@ -823,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, d => d)
}
+
+ override def prettyName: String = "to_date"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index c54bcdd774021..5f8b544edb511 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -73,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
override def prettyName: String = "promote_precision"
+ override def sql: String = child.sql
}
/**
@@ -107,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
}
override def toString: String = s"CheckOverflow($child, $dataType)"
+
+ override def sql: String = child.sql
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 672cc9c45e0af..17351ef0685a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -214,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType)
}
}
}
+
+ override def sql: String = (value, dataType) match {
+ case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null =>
+ "NULL"
+
+ case _ if value == null =>
+ s"CAST(NULL AS ${dataType.sql})"
+
+ case (v: UTF8String, StringType) =>
+ // Escapes all backslashes and double quotes.
+ "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\""
+
+ case (v: Byte, ByteType) =>
+ s"CAST($v AS ${ByteType.simpleString.toUpperCase})"
+
+ case (v: Short, ShortType) =>
+ s"CAST($v AS ${ShortType.simpleString.toUpperCase})"
+
+ case (v: Long, LongType) =>
+ s"CAST($v AS ${LongType.simpleString.toUpperCase})"
+
+ case (v: Float, FloatType) =>
+ s"CAST($v AS ${FloatType.simpleString.toUpperCase})"
+
+ case (v: Decimal, DecimalType.Fixed(precision, scale)) =>
+ s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))"
+
+ case (v: Int, DateType) =>
+ s"DATE '${DateTimeUtils.toJavaDate(v)}'"
+
+ case (v: Long, TimestampType) =>
+ s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')"
+
+ case _ => value.toString
+ }
}
// TODO: Specialize
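
One detail of the Literal.sql addition above worth calling out: string values are escaped by doubling backslashes first and then escaping double quotes, before being wrapped in double quotes. A standalone sketch of just that escaping step (not Spark code):

object LiteralEscapeDemo extends App {
  def escapeStringLiteral(s: String): String =
    "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\""

  // For the input a\b"c this prints: "a\\b\"c"
  println(escapeStringLiteral("a\\b\"c"))
}
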
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 002f5929cc26b..66d8631a846ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -70,6 +70,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
+
+ override def sql: String = s"$name(${child.sql})"
}
abstract class UnaryLogExpression(f: Double => Double, name: String)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index fd95b124b2455..cc406a39f0408 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -220,4 +220,8 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
"""
}
+
+ override def prettyName: String = "hash"
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index eefd9c7482553..eee708cb02f9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)(
explicitMetadata == a.explicitMetadata
case _ => false
}
+
+ override def sql: String = {
+ val qualifiersString =
+ if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ s"${child.sql} AS $qualifiersString`$name`"
+ }
}
/**
@@ -271,6 +277,12 @@ case class AttributeReference(
// Since the expression id is not in the first constructor it is missing from the default
// tree string.
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
+
+ override def sql: String = {
+ val qualifiersString =
+ if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".")
+ s"$qualifiersString`$name`"
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index df4747d4e6f7a..89aec2b20fd0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
"""
}.mkString("\n")
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
ev.value = eval.isNull
eval.code
}
+
+ override def sql: String = s"(${child.sql} IS NULL)"
}
@@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
ev.value = s"(!(${eval.isNull}))"
eval.code
}
+
+ override def sql: String = s"(${child.sql} IS NOT NULL)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 304b438c84ba4..bca12a8d21023 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -101,6 +101,8 @@ case class Not(child: Expression)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"!($c)")
}
+
+ override def sql: String = s"(NOT ${child.sql})"
}
@@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
}
"""
}
+
+ override def sql: String = {
+ val childrenSQL = children.map(_.sql)
+ val valueSQL = childrenSQL.head
+ val listSQL = childrenSQL.tail.mkString(", ")
+ s"($valueSQL IN ($listSQL))"
+ }
}
/**
@@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
"""
}
+
+ override def sql: String = {
+ val valueSQL = child.sql
+ val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
+ s"($valueSQL IN ($listSQL))"
+ }
}
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
@@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
}
"""
}
+
+ override def sql: String = s"(${left.sql} AND ${right.sql})"
}
@@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
}
"""
}
+
+ override def sql: String = s"(${left.sql} OR ${right.sql})"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 8bde8cb9fe876..8de47e9ddc28d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic {
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
+
+ // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
+ override def sql: String = s"$prettyName($seed)"
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index adef6050c3565..db266639b8560 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
matches(regex, input1.asInstanceOf[UTF8String].toString)
}
}
+
+ override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 50c8b9d59847e..931f752b4dc1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
@@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression])
"""
}
}
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})"
}
trait String2StringExpression extends ImplicitCastInputTypes {
@@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
val termDict = ctx.freshName("dict")
val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
- ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
- ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
- ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
+ ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;")
+ ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;")
+ ctx.addMutableState(classNameDict, termDict, s"$termDict = null;")
nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
val check = if (matchingExpr.foldable && replaceExpr.foldable) {
- s"${termDict} == null"
+ s"$termDict == null"
} else {
- s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
+ s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)"
}
s"""if ($check) {
          // Not all of them are literals, or the matching/replace values changed
- ${termLastMatching} = ${matching}.clone();
- ${termLastReplace} = ${replace}.clone();
- ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
- .buildDict(${termLastMatching}, ${termLastReplace});
+ $termLastMatching = $matching.clone();
+ $termLastReplace = $replace.clone();
+ $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate
+ .buildDict($termLastMatching, $termLastReplace);
}
- ${ev.value} = ${src}.translate(${termDict});
+ ${ev.value} = $src.translate($termDict);
"""
})
}
@@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
}
override def dataType: DataType = IntegerType
+
+ override def prettyName: String = "find_in_set"
}
/**
@@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
org.apache.commons.codec.binary.Base64.encodeBase64($child));
"""})
}
-
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0b1c74293bb8b..f8121a733a8d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -37,6 +37,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
+ Batch("Compute Current Time", Once,
+ ComputeCurrentTime) ::
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
@@ -333,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
)
Project(cleanedProjection, child)
}
+
+ // TODO Eliminate duplicate code
+ // This clause is identical to the one above except that the inner operator is an `Aggregate`
+ // rather than a `Project`.
+ case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
+ // Create a map of Aliases to their values from the child projection.
+ // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+ val aliasMap = AttributeMap(projectList2.collect {
+ case a: Alias => (a.toAttribute, a)
+ })
+
+ // We only collapse these two Projects if their overlapped expressions are all
+ // deterministic.
+ val hasNondeterministic = projectList1.exists(_.collect {
+ case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
+ }.exists(!_.deterministic))
+
+ if (hasNondeterministic) {
+ p
+ } else {
+        // Substitute any attributes that are produced by the child projection, so that we can
+        // safely eliminate it.
+        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS c ...' produces 'SELECT a + b + 1 ...'
+ // TODO: Fix TransformBase to avoid the cast below.
+ val substitutedProjection = projectList1.map(_.transform {
+ case a: Attribute => aliasMap.getOrElse(a, a)
+ }).asInstanceOf[Seq[NamedExpression]]
+        // Collapsing the two plans may introduce unnecessary Aliases; trim them here.
+ val cleanedProjection = substitutedProjection.map(p =>
+ CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ )
+ agg.copy(aggregateExpressions = cleanedProjection)
+ }
}
}
@@ -976,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
a.copy(groupingExpressions = newGrouping)
}
}
+
+/**
+ * Computes the current date and time to make sure we return the same result in a single query.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val dateExpr = CurrentDate()
+ val timeExpr = CurrentTimestamp()
+ val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+ val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+ plan transformAllExpressions {
+ case CurrentDate() => currentDate
+ case CurrentTimestamp() => currentTime
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 77dec7ca6e2b5..a5f6764aef7ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -37,14 +37,26 @@ object JoinType {
}
}
-sealed abstract class JoinType
+sealed abstract class JoinType {
+ def sql: String
+}
-case object Inner extends JoinType
+case object Inner extends JoinType {
+ override def sql: String = "INNER"
+}
-case object LeftOuter extends JoinType
+case object LeftOuter extends JoinType {
+ override def sql: String = "LEFT OUTER"
+}
-case object RightOuter extends JoinType
+case object RightOuter extends JoinType {
+ override def sql: String = "RIGHT OUTER"
+}
-case object FullOuter extends JoinType
+case object FullOuter extends JoinType {
+ override def sql: String = "FULL OUTER"
+}
-case object LeftSemi extends JoinType
+case object LeftSemi extends JoinType {
+ override def sql: String = "LEFT SEMI"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 79759b5a37b34..64957db6b4013 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -423,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
+
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 62ea731ab5f38..9ebacb4680dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -37,7 +37,7 @@ object RuleExecutor {
val maxSize = map.keys.map(_.toString.length).max
map.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
s"${k.padTo(maxSize, " ").mkString} $v"
- }.mkString("\n")
+ }.mkString("\n", "\n", "")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 71293475ca0f9..7a0d0de6328a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -130,6 +130,20 @@ package object util {
ret
}
+ /**
+ * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`.
+ */
+ def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
+ case xs if xs.isEmpty =>
+ Option(Seq.empty[T])
+
+ case xs =>
+ for {
+ head <- xs.head
+ tail <- sequenceOption(xs.tail)
+ } yield head +: tail
+ }
+
/* FIX ME
implicit class debugLogging(a: Any) {
def debugLogging() {
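
The sequenceOption helper added above is imported by several of the expression files in this patch, presumably so callers can combine the optional SQL strings of child expressions: the result is defined only when every element is defined. A small usage sketch, with the helper copied here in slightly condensed form so the snippet runs on its own:

object SequenceOptionDemo extends App {
  def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
    case xs if xs.isEmpty => Option(Seq.empty[T])
    case xs => for { head <- xs.head; tail <- sequenceOption(xs.tail) } yield head +: tail
  }

  println(sequenceOption(Seq(Some("a"), Some("b"))))  // Some(List(a, b))
  println(sequenceOption(Seq(Some("a"), None)))       // None
}
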
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 6533622492d41..520e344361625 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override def simpleString: String = s"array<${elementType.simpleString}>"
+ override def sql: String = s"ARRAY<${elementType.sql}>"
+
override private[spark] def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 136a97e066df7..92cf8d4c46bda 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -65,6 +65,8 @@ abstract class DataType extends AbstractDataType {
/** Readable string representation for the type with truncation */
private[sql] def simpleString(maxNumberFields: Int): String = simpleString
+ def sql: String = simpleString.toUpperCase
+
/**
* Check if `this` and `other` are the same data type when ignoring nullability
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 00461e529ca0a..5474954af70e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -62,6 +62,8 @@ case class MapType(
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
+ override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>"
+
override private[spark] def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 34382bf124eb0..3bd733fa2d26c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -25,8 +25,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
-import org.apache.spark.sql.catalyst.util.{LegacyTypeStringParser, DataTypeParser}
-
+import org.apache.spark.sql.catalyst.util.{DataTypeParser, LegacyTypeStringParser}
/**
* :: DeveloperApi ::
@@ -279,6 +278,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
s"struct<${fieldTypes.mkString(",")}>"
}
+ override def sql: String = {
+ val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}")
+ s"STRUCT<${fieldTypes.mkString(", ")}>"
+ }
+
private[sql] override def simpleString(maxNumberFields: Int): String = {
val builder = new StringBuilder
val fieldTypes = fields.take(maxNumberFields).map {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 4305903616bd9..d7a2c23be8a9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
+
+ override def sql: String = sqlType.sql
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
index 30978d9b49e2b..d7204c3488313 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
@@ -20,17 +20,33 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.plans.PlanTest
class CatalystQlSuite extends PlanTest {
+ val parser = new CatalystQl()
test("parse union/except/intersect") {
- val paresr = new CatalystQl()
- paresr.createPlan("select * from t1 union all select * from t2")
- paresr.createPlan("select * from t1 union distinct select * from t2")
- paresr.createPlan("select * from t1 union select * from t2")
- paresr.createPlan("select * from t1 except select * from t2")
- paresr.createPlan("select * from t1 intersect select * from t2")
- paresr.createPlan("(select * from t1) union all (select * from t2)")
- paresr.createPlan("(select * from t1) union distinct (select * from t2)")
- paresr.createPlan("(select * from t1) union (select * from t2)")
- paresr.createPlan("select * from ((select * from t1) union (select * from t2)) t")
+ parser.createPlan("select * from t1 union all select * from t2")
+ parser.createPlan("select * from t1 union distinct select * from t2")
+ parser.createPlan("select * from t1 union select * from t2")
+ parser.createPlan("select * from t1 except select * from t2")
+ parser.createPlan("select * from t1 intersect select * from t2")
+ parser.createPlan("(select * from t1) union all (select * from t2)")
+ parser.createPlan("(select * from t1) union distinct (select * from t2)")
+ parser.createPlan("(select * from t1) union (select * from t2)")
+ parser.createPlan("select * from ((select * from t1) union (select * from t2)) t")
+ }
+
+ test("window function: better support of parentheses") {
+ parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
+ "order by 2) from windowData")
+ parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
+ "order by 2) from windowData")
+ parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
+ "order by 2) from windowData")
+
+ parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
+ "from windowData")
+ parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
+ "from windowData")
+ parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
+ "from windowData")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index fa823e3021835..cf84855885a37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
class AnalysisSuite extends AnalysisTest {
@@ -238,43 +237,6 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
- test("analyzer should replace current_timestamp with literals") {
- val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
- LocalRelation())
-
- val min = System.currentTimeMillis() * 1000
- val plan = in.analyze.asInstanceOf[Project]
- val max = (System.currentTimeMillis() + 1) * 1000
-
- val lits = new scala.collection.mutable.ArrayBuffer[Long]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Long]
- e
- }
- assert(lits.size == 2)
- assert(lits(0) >= min && lits(0) <= max)
- assert(lits(1) >= min && lits(1) <= max)
- assert(lits(0) == lits(1))
- }
-
- test("analyzer should replace current_date with literals") {
- val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
-
- val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
- val plan = in.analyze.asInstanceOf[Project]
- val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
-
- val lits = new scala.collection.mutable.ArrayBuffer[Int]
- plan.transformAllExpressions { case e: Literal =>
- lits += e.value.asInstanceOf[Int]
- e
- }
- assert(lits.size == 2)
- assert(lits(0) >= min && lits(0) <= max)
- assert(lits(1) >= min && lits(1) <= max)
- assert(lits(0) == lits(1))
- }
-
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
new file mode 100644
index 0000000000000..10ed4e46ddd1c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+
+class ComputeCurrentTimeSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
+ }
+
+ test("analyzer should replace current_timestamp with literals") {
+ val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
+ LocalRelation())
+
+ val min = System.currentTimeMillis() * 1000
+ val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+ val max = (System.currentTimeMillis() + 1) * 1000
+
+ val lits = new scala.collection.mutable.ArrayBuffer[Long]
+ plan.transformAllExpressions { case e: Literal =>
+ lits += e.value.asInstanceOf[Long]
+ e
+ }
+ assert(lits.size == 2)
+ assert(lits(0) >= min && lits(0) <= max)
+ assert(lits(1) >= min && lits(1) <= max)
+ assert(lits(0) == lits(1))
+ }
+
+ test("analyzer should replace current_date with literals") {
+ val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
+
+ val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
+ val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+ val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
+
+ val lits = new scala.collection.mutable.ArrayBuffer[Int]
+ plan.transformAllExpressions { case e: Literal =>
+ lits += e.value.asInstanceOf[Int]
+ e
+ }
+ assert(lits.size == 2)
+ assert(lits(0) >= min && lits(0) <= max)
+ assert(lits(1) >= min && lits(1) <= max)
+ assert(lits(0) == lits(1))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b998636909a7d..f9f3bd55aa578 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer =
testRelation
.select('a)
- .groupBy('a)('a)
- .select('a).analyze
+ .groupBy('a)('a).analyze
comparePlans(optimized, correctAnswer)
}
@@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest {
val correctAnswer =
testRelation
.select('a)
- .groupBy('a)('a as 'c)
- .select('c).analyze
+ .groupBy('a)('a as 'c).analyze
comparePlans(optimized, correctAnswer)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 21a6fba9078df..2355de3d05865 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,7 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val buf = new ArrayBuffer[InternalRow]
val totalParts = childRDD.partitions.length
- var partsScanned = 0L
+ var partsScanned = 0
while (buf.size < n && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
@@ -183,10 +183,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = n - buf.size
- val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
+ val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
val sc = sqlContext.sparkContext
- val res =
- sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+ val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
res.foreach(buf ++= _.take(n - buf.size))
partsScanned += p.size
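
The hunk above narrows `partsScanned` to an `Int` and builds the partition range with `until`, so SPARK-12340-style limits no longer overflow. Below is a minimal, dependency-free sketch of the same incremental scan loop; the `partitions: IndexedSeq[Seq[T]]` parameter is a hypothetical stand-in for the real RDD/runJob machinery, and the batch-growth heuristic is only an approximation of Spark's policy.

object TakeLoopSketch {
  // Collects up to `n` elements by scanning partitions in growing batches, mirroring the
  // executeTake loop above.
  def take[T](partitions: IndexedSeq[Seq[T]], n: Int): Seq[T] = {
    val buf = scala.collection.mutable.ArrayBuffer.empty[T]
    val totalParts = partitions.length
    var partsScanned = 0  // Int, as in the patched SparkPlan
    while (buf.size < n && partsScanned < totalParts) {
      // Try one partition first, then grow the batch if earlier rounds came up short.
      val numPartsToTry: Long =
        if (partsScanned == 0) 1L
        else if (buf.isEmpty) partsScanned * 4L
        else partsScanned * 2L
      // Capping at totalParts keeps the Long intermediate safely convertible back to Int.
      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
      p.foreach { i =>
        val left = n - buf.size
        if (left > 0) buf ++= partitions(i).take(left)
      }
      partsScanned += p.size
    }
    buf.take(n).toSeq
  }
}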
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
index a322688a259e2..f3e89ef4a71f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
@@ -16,10 +16,10 @@
*/
package org.apache.spark.sql.execution
+import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
-import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) {
/** Check if a command should not be explained. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index be885397a7d40..168b5ab0316d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -22,6 +22,7 @@ import java.util
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-import org.apache.spark.{SparkEnv, TaskContext}
/**
* This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 4f8524f4b967c..fff72872c13b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StructType, StringType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.SerializableConfiguration
@@ -349,67 +349,6 @@ private[sql] class DynamicPartitionWriterContainer(
}
}
- private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = {
- val bucketIdIndex = partitionColumns.length
- if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) {
- false
- } else {
- var i = partitionColumns.length - 1
- while (i >= 0) {
- val dt = partitionColumns(i).dataType
- if (key1.get(i, dt) != key2.get(i, dt)) return false
- i -= 1
- }
- true
- }
- }
-
- private def sortBasedWrite(
- sorter: UnsafeKVExternalSorter,
- iterator: Iterator[InternalRow],
- getSortingKey: UnsafeProjection,
- getOutputRow: UnsafeProjection,
- getPartitionString: UnsafeProjection,
- outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
-
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
- val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
- (key1, key2) => key1 != key2
- } else {
- (key1, key2) => key1 == null || !sameBucket(key1, key2)
- }
-
- val sortedIterator = sorter.sortedIterator()
- var currentKey: UnsafeRow = null
- var currentWriter: OutputWriter = null
- try {
- while (sortedIterator.next()) {
- if (needNewWriter(currentKey, sortedIterator.getKey)) {
- if (currentWriter != null) {
- currentWriter.close()
- }
- currentKey = sortedIterator.getKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- // Either use an existing file from before, or open a new one.
- currentWriter = outputWriters.remove(currentKey)
- if (currentWriter == null) {
- currentWriter = newOutputWriter(currentKey, getPartitionString)
- }
- }
-
- currentWriter.writeInternal(sortedIterator.getValue)
- }
- } finally {
- if (currentWriter != null) { currentWriter.close() }
- }
- }
-
/**
* Open and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +374,18 @@ private[sql] class DynamicPartitionWriterContainer(
}
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
executorSideSetup(taskContext)
- var outputWritersCleared = false
-
// We should first sort by partition columns, then bucket id, and finally sorting columns.
- val getSortingKey =
- UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
- val sortingKeySchema = if (bucketSpec.isEmpty) {
- StructType.fromAttributes(partitionColumns)
- } else { // If it's bucketed, we should also consider bucket id as part of the key.
- val fields = StructType.fromAttributes(partitionColumns)
- .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
- StructType(fields)
- }
+ val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+ val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+ val sortingKeySchema = StructType(sortingExpressions.map {
+ case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+ // The sorting expressions are all `Attribute` except bucket id.
+ case _ => StructField("bucketId", IntegerType, nullable = false)
+ })
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
@@ -461,54 +396,49 @@ private[sql] class DynamicPartitionWriterContainer(
// If anything below fails, we should abort the task.
try {
- // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
- // and fill the sorter if there are too many writers and we need to fall back on sorting.
- // If there are sorting columns, then we have to sort the data anyway, and no need to try the
- // hash-based writing first.
- var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
- new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
+ // Sorts the data before writing, so that we only need one open writer at a time.
+ // TODO: inject a local sort operator in planning.
+ val sorter = new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+ val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+ identity
} else {
- null
+ UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+ case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ })
}
- while (iterator.hasNext && sorter == null) {
- val inputRow = iterator.next()
- // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
- val currentKey = getSortingKey(inputRow)
- var currentWriter = outputWriters.get(currentKey)
-
- if (currentWriter == null) {
- if (outputWriters.size < maxOpenFiles) {
+
+ val sortedIterator = sorter.sortedIterator()
+ var currentKey: UnsafeRow = null
+ var currentWriter: OutputWriter = null
+ try {
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
currentWriter = newOutputWriter(currentKey, getPartitionString)
- outputWriters.put(currentKey.copy(), currentWriter)
- currentWriter.writeInternal(getOutputRow(inputRow))
- } else {
- logInfo(s"Maximum partitions reached, falling back on sorting.")
- sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
- sorter.insertKV(currentKey, getOutputRow(inputRow))
}
- } else {
- currentWriter.writeInternal(getOutputRow(inputRow))
- }
- }
- // If the sorter is not null that means that we reached the maxFiles above and need to finish
- // using external sort, or there are sorting columns and we need to sort the whole data set.
- if (sorter != null) {
- sortBasedWrite(
- sorter,
- iterator,
- getSortingKey,
- getOutputRow,
- getPartitionString,
- outputWriters)
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ } finally {
+ if (currentWriter != null) { currentWriter.close() }
}
commitTask()
@@ -518,31 +448,5 @@ private[sql] class DynamicPartitionWriterContainer(
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
-
- def clearOutputWriters(): Unit = {
- if (!outputWritersCleared) {
- outputWriters.asScala.values.foreach(_.close())
- outputWriters.clear()
- outputWritersCleared = true
- }
- }
-
- def commitTask(): Unit = {
- try {
- clearOutputWriters()
- super.commitTask()
- } catch {
- case cause: Throwable =>
- throw new RuntimeException("Failed to commit task", cause)
- }
- }
-
- def abortTask(): Unit = {
- try {
- clearOutputWriters()
- } finally {
- super.abortTask()
- }
- }
}
}
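
The rewrite above drops the hash-based map of open writers entirely: rows are always sorted by (partition columns, bucket id, sort columns), so only one output writer needs to be open at a time. A minimal sketch of that single-writer pattern, with a hypothetical `Writer` trait and string keys standing in for `OutputWriter` and `UnsafeRow`:

object SortedWriteSketch {
  trait Writer { def write(row: String): Unit; def close(): Unit }

  def writeSorted(rows: Seq[(String, String)], newWriter: String => Writer): Unit = {
    // Sorting guarantees all rows for a given key are contiguous, so one open writer suffices.
    val sorted = rows.sortBy(_._1)
    var currentKey: String = null
    var currentWriter: Writer = null
    try {
      sorted.foreach { case (key, value) =>
        if (key != currentKey) {
          // Key changed: close the previous file and open the next one.
          if (currentWriter != null) currentWriter.close()
          currentKey = key
          currentWriter = newWriter(key)
        }
        currentWriter.write(value)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }
}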
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
index 82287c8967134..9976829638d70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.execution.datasources
import org.apache.hadoop.mapreduce.TaskAttemptContext
+
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation}
+import org.apache.spark.sql.sources.{HadoopFsRelation, HadoopFsRelationProvider, OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index 2e3fe3da15389..b2f5c1e96421d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -90,7 +90,7 @@ object JacksonParser {
DateTimeUtils.stringToTime(parser.getText).getTime * 1000L
case (VALUE_NUMBER_INT, TimestampType) =>
- parser.getLongValue * 1000L
+ parser.getLongValue * 1000000L
case (_, StringType) =>
val writer = new ByteArrayOutputStream()
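
The one-line change above treats an integral JSON value as seconds since the epoch and converts it to Catalyst's microsecond-precision timestamps (hence `* 1000000L` rather than `* 1000L`). A small standalone sketch of the unit conversion, reusing the epoch value from the JsonSuite test added later in this patch:

object TimestampUnitSketch {
  // Catalyst stores TimestampType values as microseconds since the epoch.
  def secondsToCatalystMicros(seconds: Long): Long = seconds * 1000000L

  def main(args: Array[String]): Unit = {
    val micros = secondsToCatalystMicros(1451732645L)
    // java.sql.Timestamp expects milliseconds; printing shows the wall-clock time in the
    // JVM's default time zone.
    println(new java.sql.Timestamp(micros / 1000L))
  }
}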
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 4b375de05e9e3..7754edc803d10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -44,9 +44,9 @@ import org.apache.spark.{Logging, Partition => SparkPartition, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -147,6 +147,12 @@ private[sql] class ParquetRelation(
.get(ParquetRelation.METASTORE_SCHEMA)
.map(DataType.fromJson(_).asInstanceOf[StructType])
+ // If this relation is converted from a Hive metastore table, this method returns the name of the
+ // original Hive metastore table.
+ private[sql] def metastoreTableName: Option[TableIdentifier] = {
+ parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier)
+ }
+
private lazy val metadataCache: MetadataCache = {
val meta = new MetadataCache
meta.refresh()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index d484403d1c641..1c773e69275db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions.{RowOrdering, Alias, Attribute, Cast}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c35f33132f602..9f3607369c30f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
partitionColumns: Option[StructType],
parameters: Map[String, String]): HadoopFsRelation
- // TODO: expose bucket API to users.
private[sql] def createRelation(
sqlContext: SQLContext,
paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter
- // TODO: expose bucket API to users.
private[sql] def newInstance(
path: String,
bucketId: Option[Int],
@@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](
private var _partitionSpec: PartitionSpec = _
- // TODO: expose bucket API to users.
private[sql] def bucketSpec: Option[BucketSpec] = None
private class FileStatusCache {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ade1391ecd74a..983dfbdedeefe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(
mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
+
+ // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake
+ checkAnswer(
+ sqlContext.range(2).limit(2147483638),
+ Row(0) :: Row(1) :: Nil
+ )
}
test("except") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bd987ae1bb03a..5de0979606b88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,16 +2067,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}
-
- test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
- val rdd = sqlContext.sparkContext.parallelize(1 to 3 , 3 )
- rdd.toDF("key").registerTempTable("spark12340")
- checkAnswer(
- sql("select key from spark12340 limit 2147483638"),
- Row(1) :: Row(2) :: Row(3) :: Nil
- )
- assert(rdd.take(2147483638).size === 3)
- assert(rdd.takeAsync(2147483638).get.size === 3)
- }
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index b3b6b7df0c1d1..4ab148065a476 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -83,9 +83,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
- checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber * 1000L)),
enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)),
+ checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
@@ -1465,4 +1465,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
+ test("Casting long as timestamp") {
+ withTempTable("jsonTable") {
+ val schema = (new StructType).add("ts", TimestampType)
+ val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong)
+
+ jsonDF.registerTempTable("jsonTable")
+
+ checkAnswer(
+ sql("select ts from jsonTable"),
+ Row(java.sql.Timestamp.valueOf("2016-01-02 03:04:05"))
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index cb61f7eeca0de..a0836058d3c74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -205,6 +205,10 @@ private[json] trait TestJsonData {
"""{"b": [{"c": {}}]}""" ::
"""]""" :: Nil)
+ def timestampAsLong: RDD[String] =
+ sqlContext.sparkContext.parallelize(
+ """{"ts":1451732645}""" :: Nil)
+
lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
index cab6abde6da23..ae95b50e1ee76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala
@@ -21,9 +21,9 @@ import java.io.File
import scala.collection.JavaConverters._
import scala.util.Try
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLConf, SQLContext}
import org.apache.spark.util.{Benchmark, Utils}
-import org.apache.spark.{SparkConf, SparkContext}
/**
* Benchmark to measure parquet read performance.
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index bd1a52e5f3303..afd2f611580fc 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
- def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
+ def testCases: Seq[(String, File)] = {
+ hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
+ }
override def beforeAll() {
+ super.beforeAll()
TestHive.cacheTables = true
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
@@ -68,10 +71,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// For debugging dump some statistics about how much time was spent in various optimizer rules.
logWarning(RuleExecutor.dumpTimeSpent())
+ super.afterAll()
}
/** A list of tests deemed out of scope currently and thus completely disregarded. */
- override def blackList = Seq(
+ override def blackList: Seq[String] = Seq(
// These tests use hooks that are not on the classpath and thus break all subsequent execution.
"hook_order",
"hook_context_cs",
@@ -106,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"alter_merge",
"alter_concatenate_indexed_table",
"protectmode2",
- //"describe_table",
+ // "describe_table",
"describe_comment_nonascii",
"create_merge_compressed",
@@ -323,7 +327,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
* The set of tests that are believed to be working in catalyst. Tests not on whiteList or
* blacklist are implicitly marked as ignored.
*/
- override def whiteList = Seq(
+ override def whiteList: Seq[String] = Seq(
"add_part_exist",
"add_part_multiple",
"add_partition_no_whitelist",
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 98bbdf0653c2a..bad3ca6da231f 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
TestHive.reset()
+ super.afterAll()
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index bf3fe12d5c5d2..d1b1c0d8d8bc2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -24,11 +24,12 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
+import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.parse.EximUtil
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions._
@@ -38,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkQl
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema}
+import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema}
import org.apache.spark.sql.types._
import org.apache.spark.sql.AnalysisException
@@ -668,7 +669,8 @@ private[hive] object HiveQl extends SparkQl with Logging {
Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName
- HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
+ HiveGenericUDTF(
+ functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
case other => super.nodeToGenerator(node)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
new file mode 100644
index 0000000000000..61e3f183bb42d
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
+
+/**
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
+ * all resolved logical plans are convertible. They either don't have corresponding SQL
+ * representations (e.g. logical plans that operate on local Scala collections), or are simply not
+ * supported by this builder (yet).
+ */
+class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
+ def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
+
+ def toSQL: Option[String] = {
+ val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
+ val maybeSQL = try {
+ toSQL(canonicalizedPlan)
+ } catch { case cause: UnsupportedOperationException =>
+ logInfo(s"Failed to build SQL query string because: ${cause.getMessage}")
+ None
+ }
+
+ if (maybeSQL.isDefined) {
+ logDebug(
+ s"""Built SQL query string successfully from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${canonicalizedPlan.treeString}
+ |# Built SQL query string:
+ |${maybeSQL.get}
+ """.stripMargin)
+ } else {
+ logDebug(
+ s"""Failed to build SQL query string from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${canonicalizedPlan.treeString}
+ """.stripMargin)
+ }
+
+ maybeSQL
+ }
+
+ private def projectToSQL(
+ projectList: Seq[NamedExpression],
+ child: LogicalPlan,
+ isDistinct: Boolean): Option[String] = {
+ for {
+ childSQL <- toSQL(child)
+ listSQL = projectList.map(_.sql).mkString(", ")
+ maybeFrom = child match {
+ case OneRowRelation => " "
+ case _ => " FROM "
+ }
+ distinct = if (isDistinct) " DISTINCT " else " "
+ } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL"
+ }
+
+ private def aggregateToSQL(
+ groupingExprs: Seq[Expression],
+ aggExprs: Seq[Expression],
+ child: LogicalPlan): Option[String] = {
+ val aggSQL = aggExprs.map(_.sql).mkString(", ")
+ val groupingSQL = groupingExprs.map(_.sql).mkString(", ")
+ val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY "
+ val maybeFrom = child match {
+ case OneRowRelation => " "
+ case _ => " FROM "
+ }
+
+ toSQL(child).map { childSQL =>
+ s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL"
+ }
+ }
+
+ private def toSQL(node: LogicalPlan): Option[String] = node match {
+ case Distinct(Project(list, child)) =>
+ projectToSQL(list, child, isDistinct = true)
+
+ case Project(list, child) =>
+ projectToSQL(list, child, isDistinct = false)
+
+ case Aggregate(groupingExprs, aggExprs, child) =>
+ aggregateToSQL(groupingExprs, aggExprs, child)
+
+ case Limit(limit, child) =>
+ for {
+ childSQL <- toSQL(child)
+ limitSQL = limit.sql
+ } yield s"$childSQL LIMIT $limitSQL"
+
+ case Filter(condition, child) =>
+ for {
+ childSQL <- toSQL(child)
+ whereOrHaving = child match {
+ case _: Aggregate => "HAVING"
+ case _ => "WHERE"
+ }
+ conditionSQL = condition.sql
+ } yield s"$childSQL $whereOrHaving $conditionSQL"
+
+ case Union(left, right) =>
+ for {
+ leftSQL <- toSQL(left)
+ rightSQL <- toSQL(right)
+ } yield s"$leftSQL UNION ALL $rightSQL"
+
+ // ParquetRelation converted from Hive metastore table
+ case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) =>
+ // There seems to be a bug related to the `ParquetConversions` analysis rule. The problem
+ // is that the metastore database name and table name are not always propagated to converted
+ // `ParquetRelation` instances via data source options. Here we use subquery alias as a
+ // workaround.
+ Some(s"`$alias`")
+
+ case Subquery(alias, child) =>
+ toSQL(child).map(childSQL => s"($childSQL) AS $alias")
+
+ case Join(left, right, joinType, condition) =>
+ for {
+ leftSQL <- toSQL(left)
+ rightSQL <- toSQL(right)
+ joinTypeSQL = joinType.sql
+ conditionSQL = condition.map(" ON " + _.sql).getOrElse("")
+ } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL"
+
+ case MetastoreRelation(database, table, alias) =>
+ val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("")
+ Some(s"`$database`.`$table`$aliasSQL")
+
+ case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
+ if orders.map(_.child) == partitionExprs =>
+ for {
+ childSQL <- toSQL(child)
+ partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
+ } yield s"$childSQL CLUSTER BY $partitionExprsSQL"
+
+ case Sort(orders, global, child) =>
+ for {
+ childSQL <- toSQL(child)
+ ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ")
+ orderOrSort = if (global) "ORDER" else "SORT"
+ } yield s"$childSQL $orderOrSort BY $ordersSQL"
+
+ case RepartitionByExpression(partitionExprs, child, _) =>
+ for {
+ childSQL <- toSQL(child)
+ partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ")
+ } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL"
+
+ case OneRowRelation =>
+ Some("")
+
+ case _ => None
+ }
+
+ object Canonicalizer extends RuleExecutor[LogicalPlan] {
+ override protected def batches: Seq[Batch] = Seq(
+ Batch("Canonicalizer", FixedPoint(100),
+ // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
+ // `Aggregate`s to perform type casting. This rule merges these `Project`s into
+ // `Aggregate`s.
+ ProjectCollapsing,
+
+ // Used to handle other auxiliary `Project`s added by analyzer (e.g.
+ // `ResolveAggregateFunctions` rule)
+ RecoverScopingInfo
+ )
+ )
+
+ object RecoverScopingInfo extends Rule[LogicalPlan] {
+ override def apply(tree: LogicalPlan): LogicalPlan = tree transform {
+ // This branch handles aggregate functions within HAVING clauses. For example:
+ //
+ // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
+ //
+ // This kind of query results in query plans of the following form because of analysis rule
+ // `ResolveAggregateFunctions`:
+ //
+ // Project ...
+ // +- Filter ...
+ // +- Aggregate ...
+ // +- MetastoreRelation default, src, None
+ case plan @ Project(_, Filter(_, _: Aggregate)) =>
+ wrapChildWithSubquery(plan)
+
+ case plan @ Project(_,
+ _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit
+ ) => plan
+
+ case plan: Project =>
+ wrapChildWithSubquery(plan)
+ }
+
+ def wrapChildWithSubquery(project: Project): Project = project match {
+ case Project(projectList, child) =>
+ val alias = SQLBuilder.newSubqueryName
+ val childAttributes = child.outputSet
+ val aliasedProjectList = projectList.map(_.transform {
+ case a: Attribute if childAttributes.contains(a) =>
+ a.withQualifiers(alias :: Nil)
+ }.asInstanceOf[NamedExpression])
+
+ Project(aliasedProjectList, Subquery(alias, child))
+ }
+ }
+ }
+}
+
+object SQLBuilder {
+ private val nextSubqueryId = new AtomicLong(0)
+
+ private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
+}
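
For context, here is a usage sketch of the new builder in spark-shell style, mirroring what `LogicalPlanToSQLSuite` does further down; the `sqlContext` (a HiveContext) and the table `t0` are assumed to exist already and are not part of this patch.

import org.apache.spark.sql.hive.SQLBuilder

val df = sqlContext.sql("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
// toSQL returns None for plans that have no SQL representation (yet).
val regenerated: Option[String] = new SQLBuilder(df).toSQL
regenerated match {
  case Some(sqlString) => println(s"Round-tripped SQL: $sqlString")
  case None => println("This plan has no SQL representation (yet)")
}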
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index b1a6d0ab7df3c..56cab1aee89df 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.hive
-import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
@@ -32,15 +31,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.Obje
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.{analysis, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.sequenceOption
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.client.ClientWrapper
import org.apache.spark.sql.types._
@@ -75,19 +71,19 @@ private[hive] class HiveFunctionRegistry(
try {
if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUDF(
- new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+ name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
} else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children)
+ HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
+ HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
+ HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUDAFFunction(
- new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+ name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
+ val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
udtf.elementTypes // Force it to check input data types.
udtf
} else {
@@ -137,7 +133,8 @@ private[hive] class HiveFunctionRegistry(
}
}
-private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+private[hive] case class HiveSimpleUDF(
+ name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
override def deterministic: Boolean = isUDFDeterministic
@@ -191,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -205,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
override def get(): AnyRef = wrap(func(), oi, dataType)
}
-private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+private[hive] case class HiveGenericUDF(
+ name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
override def nullable: Boolean = true
@@ -257,6 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -271,6 +273,7 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUDTF(
+ name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
extends Generator with HiveInspectors with CodegenFallback {
@@ -336,6 +339,8 @@ private[hive] case class HiveGenericUDTF(
override def toString: String = {
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
+
+ override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
}
/**
@@ -343,6 +348,7 @@ private[hive] case class HiveGenericUDTF(
* performance a lot.
*/
private[hive] case class HiveUDAFFunction(
+ name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression],
isUDAFBridgeRequired: Boolean = false,
@@ -427,5 +433,9 @@ private[hive] case class HiveUDAFFunction(
override def supportsPartial: Boolean = false
override val dataType: DataType = inspectorToDataType(returnInspector)
-}
+ override def sql(isDistinct: Boolean): String = {
+ val distinct = if (isDistinct) "DISTINCT " else " "
+ s"$name($distinct${children.map(_.sql).mkString(", ")})"
+ }
+}
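
The `name` field threaded through the Hive UDF wrappers above exists so the new `sql` overrides can render a readable call. A tiny standalone sketch of the string they produce; the `renderCall` helper is hypothetical and only mirrors the interpolation used in `HiveUDAFFunction.sql`:

object UdfSqlSketch {
  // Mirrors HiveUDAFFunction.sql above; note the single space kept for the non-distinct case.
  def renderCall(name: String, args: Seq[String], isDistinct: Boolean = false): String = {
    val distinct = if (isDistinct) "DISTINCT " else " "
    s"$name($distinct${args.mkString(", ")})"
  }

  def main(args: Array[String]): Unit = {
    println(renderCall("count", Seq("`id`"), isDistinct = true)) // count(DISTINCT `id`)
    println(renderCall("max", Seq("`key`")))                     // max( `key`)
  }
}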
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index d26cb48479066..033746d42f557 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -37,8 +37,8 @@ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.hive.execution.HiveNativeCommand
import org.apache.spark.sql.hive.client.ClientWrapper
+import org.apache.spark.sql.hive.execution.HiveNativeCommand
import org.apache.spark.util.{ShutdownHookManager, Utils}
// SPARK-3729: Test key required to check for initialization errors with config.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index a2d283622ca52..e72a18a716b5c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -21,8 +21,8 @@ import scala.util.Try
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.catalyst.parser.ParseDriver
import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.parser.ParseDriver
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.hive.test.TestHiveSingleton
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
new file mode 100644
index 0000000000000..3a6eb57add4e3
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+
+class ExpressionSQLBuilderSuite extends SQLBuilderTest {
+ test("literal") {
+ checkSQL(Literal("foo"), "\"foo\"")
+ checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
+ checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
+ checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+ checkSQL(Literal(4: Int), "4")
+ checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+ checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
+ checkSQL(Literal(2.5D), "2.5")
+ checkSQL(
+ Literal(Timestamp.valueOf("2016-01-01 00:00:00")),
+ "TIMESTAMP('2016-01-01 00:00:00.0')")
+ // TODO tests for decimals
+ }
+
+ test("binary comparisons") {
+ checkSQL('a.int === 'b.int, "(`a` = `b`)")
+ checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
+ checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))")
+
+ checkSQL('a.int < 'b.int, "(`a` < `b`)")
+ checkSQL('a.int <= 'b.int, "(`a` <= `b`)")
+ checkSQL('a.int > 'b.int, "(`a` > `b`)")
+ checkSQL('a.int >= 'b.int, "(`a` >= `b`)")
+
+ checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))")
+ checkSQL('a.int in (1, 2), "(`a` IN (1, 2))")
+
+ checkSQL('a.int.isNull, "(`a` IS NULL)")
+ checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)")
+ }
+
+ test("logical operators") {
+ checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)")
+ checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)")
+ checkSQL(!'a.boolean, "(NOT `a`)")
+ checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))")
+ }
+
+ test("arithmetic expressions") {
+ checkSQL('a.int + 'b.int, "(`a` + `b`)")
+ checkSQL('a.int - 'b.int, "(`a` - `b`)")
+ checkSQL('a.int * 'b.int, "(`a` * `b`)")
+ checkSQL('a.int / 'b.int, "(`a` / `b`)")
+ checkSQL('a.int % 'b.int, "(`a` % `b`)")
+
+ checkSQL(-'a.int, "(-`a`)")
+ checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))")
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
new file mode 100644
index 0000000000000..9a8a9c51183da
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+
+class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ sqlContext.range(10).write.saveAsTable("t0")
+
+ sqlContext
+ .range(10)
+ .select('id as 'key, concat(lit("val_"), 'id) as 'value)
+ .write
+ .saveAsTable("t1")
+
+ sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+ }
+
+ override protected def afterAll(): Unit = {
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+ }
+
+ private def checkHiveQl(hiveQl: String): Unit = {
+ val df = sql(hiveQl)
+ val convertedSQL = new SQLBuilder(df).toSQL
+
+ if (convertedSQL.isEmpty) {
+ fail(
+ s"""Cannot convert the following HiveQL query plan back to SQL query string:
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin)
+ }
+
+ val sqlString = convertedSQL.get
+ try {
+ checkAnswer(sql(sqlString), df)
+ } catch { case cause: Throwable =>
+ fail(
+ s"""Failed to execute converted SQL string or got wrong answer:
+ |
+ |# Converted SQL query string:
+ |$sqlString
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin,
+ cause)
+ }
+ }
+
+ test("in") {
+ checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
+ }
+
+ test("aggregate function in having clause") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
+ }
+
+ test("aggregate function in order by clause") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)")
+ }
+
+ // TODO Fix the name collision introduced by the ResolveAggregateFunctions analysis rule.
+ // When there are multiple aggregate functions in the ORDER BY clause, all of them are extracted
+ // into the Aggregate operator and aliased to the same name "aggOrder". This is fine for normal
+ // query execution since the aliases have different expression IDs, but it introduces a name
+ // collision when converting resolved plans back to SQL query strings, as expression IDs are
+ // stripped.
+ ignore("aggregate function in order by clause with multiple order keys") {
+ checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)")
+ }
+
+ test("type widening in union") {
+ checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
+ }
+
+ test("case") {
+ checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
+ }
+
+ test("case with else") {
+ checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0")
+ }
+
+ test("case with key") {
+ checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0")
+ }
+
+ test("case with key and else") {
+ checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0")
+ }
+
+ test("select distinct without aggregate functions") {
+ checkHiveQl("SELECT DISTINCT id FROM t0")
+ }
+
+ test("cluster by") {
+ checkHiveQl("SELECT id FROM t0 CLUSTER BY id")
+ }
+
+ test("distribute by") {
+ checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id")
+ }
+
+ test("distribute by with sort by") {
+ checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id")
+ }
+
+ test("distinct aggregation") {
+ checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0")
+ }
+
+ // TODO Enable this test once query plans transformed by DistinctAggregationRewriter are
+ // recognized by SQLBuilder
+ ignore("distinct and non-distinct aggregation") {
+ checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a")
+ }
+}
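
The suite above always goes through checkHiveQl, so the underlying round-trip contract is easy to miss. A minimal sketch of that contract, assuming a SQLContext named sqlContext with the t0 table registered as in beforeAll() (everything except SQLBuilder and toSQL is illustrative):

    // Round-trip sketch: generate SQL from a resolved plan, re-run it, and compare answers.
    val original = sqlContext.sql("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
    new SQLBuilder(original).toSQL match {
      case Some(regenerated) =>
        // The regenerated SQL string must return the same rows as the original query.
        assert(sqlContext.sql(regenerated).collect().toSet === original.collect().toSet)
      case None =>
        // Not every resolved plan is convertible yet; the suite fails loudly in that case.
        sys.error("Plan could not be converted back to SQL")
    }
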
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
new file mode 100644
index 0000000000000..a5e209ac9db3b
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+
+abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
+ protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
+ val actualSQL = e.sql
+ try {
+ assert(actualSQL === expectedSQL)
+ } catch {
+ case cause: Throwable =>
+ fail(
+ s"""Wrong SQL generated for the following expression:
+ |
+ |${e.prettyName}
+ |
+ |$cause
+ """.stripMargin)
+ }
+ }
+
+ protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
+ val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL
+
+ if (maybeSQL.isEmpty) {
+ fail(
+ s"""Cannot convert the following logical query plan to SQL:
+ |
+ |${plan.treeString}
+ """.stripMargin)
+ }
+
+ val actualSQL = maybeSQL.get
+
+ try {
+ assert(actualSQL === expectedSQL)
+ } catch {
+ case cause: Throwable =>
+ fail(
+ s"""Wrong SQL generated for the following logical query plan:
+ |
+ |${plan.treeString}
+ |
+ |$cause
+ """.stripMargin)
+ }
+
+ checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan))
+ }
+
+ protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
+ checkSQL(df.queryExecution.analyzed, expectedSQL)
+ }
+}
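
For reference, a hypothetical suite built on the helpers above could look like the sketch below; the table name and the expected SQL string are placeholders, not the exact text SQLBuilder emits:

    // Sketch only: exercising the DataFrame overload of SQLBuilderTest.checkSQL.
    class ExampleSQLBuilderSuite extends SQLBuilderTest {
      test("simple projection round-trips through SQL generation") {
        val df = sqlContext.table("t0").select("id")
        // Compares the generated SQL text against the expectation, then re-executes it
        // and checks the answer against the original plan.
        checkSQL(df, "SELECT `id` FROM `t0`")
      }
    }
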
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index d7e8ebc8d312f..fd3339a66bec0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.{ExplainCommand, SetCommand}
import org.apache.spark.sql.execution.datasources.DescribeCommand
+import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder}
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -130,6 +131,28 @@ abstract class HiveComparisonTest
new java.math.BigInteger(1, digest.digest).toString(16)
}
+ /** Used for testing [[SQLBuilder]] */
+ private var numConvertibleQueries: Int = 0
+ private var numTotalQueries: Int = 0
+
+ override protected def afterAll(): Unit = {
+ logInfo({
+ val percentage = if (numTotalQueries > 0) {
+ numConvertibleQueries.toDouble / numTotalQueries * 100
+ } else {
+ 0D
+ }
+
+ s"""SQLBuiler statistics:
+ |- Total query number: $numTotalQueries
+ |- Number of convertible queries: $numConvertibleQueries
+ |- Percentage of convertible queries: $percentage%
+ """.stripMargin
+ })
+
+ super.afterAll()
+ }
+
protected def prepareAnswer(
hiveQuery: TestHive.type#QueryExecution,
answer: Seq[String]): Seq[String] = {
@@ -372,8 +395,49 @@ abstract class HiveComparisonTest
// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- val query = new TestHive.QueryExecution(queryString)
- try { (query, prepareAnswer(query, query.stringResult())) } catch {
+ var query: TestHive.QueryExecution = null
+ try {
+ query = {
+ val originalQuery = new TestHive.QueryExecution(queryString)
+ val containsCommands = originalQuery.analyzed.collectFirst {
+ case _: Command => ()
+ case _: LogicalInsertIntoHiveTable => ()
+ }.nonEmpty
+
+ if (containsCommands) {
+ originalQuery
+ } else {
+ numTotalQueries += 1
+ new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql =>
+ numConvertibleQueries += 1
+ logInfo(
+ s"""
+ |### Running SQL generation round-trip test {{{
+ |${originalQuery.analyzed.treeString}
+ |Original SQL:
+ |$queryString
+ |
+ |Generated SQL:
+ |$sql
+ |}}}
+ """.stripMargin.trim)
+ new TestHive.QueryExecution(sql)
+ }.getOrElse {
+ logInfo(
+ s"""
+ |### Cannot convert the following logical plan back to SQL {{{
+ |${originalQuery.analyzed.treeString}
+ |Original SQL:
+ |$queryString
+ |}}}
+ """.stripMargin.trim)
+ originalQuery
+ }
+ }
+ }
+
+ (query, prepareAnswer(query, query.stringResult()))
+ } catch {
case e: Throwable =>
val errorMessage =
s"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index fa99289b41971..4659d745fe78b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
sql("DROP TEMPORARY FUNCTION udtf_count2")
+ super.afterAll()
}
test("SPARK-4908: concurrent hive native commands") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 579da0291f291..7f1745705aaaf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{AnalysisException, QueryTest}
class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index b186d297610e2..86f01d2168729 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -27,8 +27,8 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.util.Utils
import org.apache.spark.streaming.scheduler.JobGenerator
+import org.apache.spark.util.Utils
private[streaming]
class Checkpoint(ssc: StreamingContext, val checkpointTime: Time)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 4e5baebaae04b..4ccc905b275d9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -25,7 +25,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.SparkConf
-import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge}
+import org.apache.spark.serializer.{KryoInputObjectInputBridge, KryoOutputObjectOutputBridge}
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
import org.apache.spark.util.collection.OpenHashMap
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index ea32bbf95ce59..da0430e263b5f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.streaming
-import org.apache.spark.streaming.rdd.MapWithStateRDDRecord
-
import scala.collection.{immutable, mutable, Map}
import scala.reflect.ClassTag
import scala.util.Random
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
-import com.esotericsoftware.kryo.io.{Output, Input}
+import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer._
+import org.apache.spark.streaming.rdd.MapWithStateRDDRecord
import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
class StateMapSuite extends SparkFunSuite {