diff --git a/LICENSE b/LICENSE
index 9d1b00beff748..d0cd0dcb4bdb7 100644
--- a/LICENSE
+++ b/LICENSE
@@ -853,6 +853,52 @@ and
Vis.js may be distributed under either license.
+========================================================================
+For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js):
+========================================================================
+Copyright (c) 2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+========================================================================
+For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js):
+========================================================================
+Copyright (c) 2012-2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 411126a377950..f9447f6c3288d 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -19,9 +19,11 @@ exportMethods("arrange",
"count",
"describe",
"distinct",
+ "dropna",
"dtypes",
"except",
"explain",
+ "fillna",
"filter",
"first",
"group_by",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index ed8093c80d360..0af5cb8881e35 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1314,9 +1314,8 @@ setMethod("except",
#' write.df(df, "myfile", "parquet", "overwrite")
#' }
setMethod("write.df",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
@@ -1338,9 +1337,8 @@ setMethod("write.df",
#' @aliases saveDF
#' @export
setMethod("saveDF",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
write.df(df, path, source, mode, ...)
})
@@ -1431,3 +1429,128 @@ setMethod("describe",
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
dataFrame(sdf)
})
+
+#' dropna
+#'
+#' Returns a new DataFrame omitting rows with null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param how "any" or "all".
+#' if "any", drop a row if it contains any nulls.
+#' if "all", drop a row only if all its values are null.
+#' if minNonNulls is specified, how is ignored.
+#' @param minNonNulls If specified, drop rows that have less than
+#' minNonNulls non-null values.
+#' This overrides the how parameter.
+#' @param cols Optional list of column names to consider.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' dropna(df)
+#' }
+setMethod("dropna",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ how <- match.arg(how)
+ if (is.null(cols)) {
+ cols <- columns(x)
+ }
+ if (is.null(minNonNulls)) {
+ minNonNulls <- if (how == "any") { length(cols) } else { 1 }
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- callJMethod(naFunctions, "drop",
+ as.integer(minNonNulls), listToSeq(as.list(cols)))
+ dataFrame(sdf)
+ })
+
+#' @aliases dropna
+#' @export
+setMethod("na.omit",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ dropna(x, how, minNonNulls, cols)
+ })
+
+#' fillna
+#'
+#' Replace null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param value Value to replace null values with.
+#' Should be an integer, numeric, character or named list.
+#' If the value is a named list, then cols is ignored and
+#' value must be a mapping from column name (character) to
+#' replacement value. The replacement value must be an
+#' integer, numeric or character.
+#' @param cols optional list of column names to consider.
+#' Columns specified in cols that do not have matching data
+#' type are ignored. For example, if value is a character and
+#' cols contains a non-character column, then the non-character
+#' column is simply ignored.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' fillna(df, 1)
+#' fillna(df, list("age" = 20, "name" = "unknown"))
+#' }
+setMethod("fillna",
+ signature(x = "DataFrame"),
+ function(x, value, cols = NULL) {
+ if (!(class(value) %in% c("integer", "numeric", "character", "list"))) {
+ stop("value should be an integer, numeric, charactor or named list.")
+ }
+
+ if (class(value) == "list") {
+ # Check column names in the named list
+ colNames <- names(value)
+ if (length(colNames) == 0 || !all(colNames != "")) {
+ stop("value should be an a named list with each name being a column name.")
+ }
+
+ # Convert the named list to an environment to be passed to the JVM
+ valueMap <- new.env()
+ for (col in colNames) {
+ # Check each item in the named list is of valid type
+ v <- value[[col]]
+ if (!(class(v) %in% c("integer", "numeric", "character"))) {
+ stop("Each item in value should be an integer, numeric or charactor.")
+ }
+ valueMap[[col]] <- v
+ }
+
+ # When value is a named list, caller is expected not to pass in cols
+ if (!is.null(cols)) {
+ warning("When value is a named list, cols is ignored!")
+ cols <- NULL
+ }
+
+ value <- valueMap
+ } else if (is.integer(value)) {
+ # Cast an integer to a numeric
+ value <- as.numeric(value)
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- if (length(cols) == 0) {
+ callJMethod(naFunctions, "fill", value)
+ } else {
+ callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+ }
+ dataFrame(sdf)
+ })
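For reference, both wrappers above delegate to Spark SQL's `DataFrameNaFunctions`, obtained on the JVM side via `callJMethod(x@sdf, "na")`. A minimal Scala sketch of the equivalent JVM calls (column names and replacement values are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.DataFrame

// Scala counterparts of the SparkR calls above, via DataFrame.na.
// Assumes a DataFrame `df` with "age", "height" and "name" columns.
def naExample(df: DataFrame): (DataFrame, DataFrame) = {
  val dropped = df.na.drop(2, Seq("age", "height"))              // dropna(df, minNonNulls = 2, cols = c("age", "height"))
  val filled = df.na.fill(Map("age" -> 50, "name" -> "unknown")) // fillna(df, list(age = 50, name = "unknown"))
  (dropped, filled)
}
```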
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 36cc612875879..88e1a508f37c4 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -457,6 +457,11 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) {
if (!is.null(path)) {
options[['path']] <- path
}
+ if (is.null(source)) {
+ sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
+ source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ }
sdf <- callJMethod(sqlContext, "load", source, options)
dataFrame(sdf)
}
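This mirrors the defaulting already done in `write.df`: when `source` is not supplied, fall back to `spark.sql.sources.default`. A hedged Scala sketch of the same two JVM calls (the path option is illustrative):

```scala
// Sketch of the JVM-side calls issued by the R code above,
// assuming an existing SQLContext `sqlContext`.
val source = sqlContext.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet")
val df = sqlContext.load(source, Map("path" -> "path/to/data"))
```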
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a23d3b217b2fd..12e09176c9f92 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") })
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
+#' @rdname nafunctions
+#' @export
+setGeneric("dropna",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("dropna")
+ })
+
+#' @rdname nafunctions
+#' @export
+setGeneric("na.omit",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("na.omit")
+ })
+
#' @rdname schema
#' @export
setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
@@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") })
#' @export
setGeneric("except", function(x, y) { standardGeneric("except") })
+#' @rdname nafunctions
+#' @export
+setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") })
+
#' @rdname filter
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
@@ -482,11 +500,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
#' @rdname write.df
#' @export
-setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") })
+setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") })
#' @rdname schema
#' @export
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index c53d0a961016f..2081786e6f833 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -160,6 +160,14 @@ writeList <- function(con, arr) {
}
}
+# Used to pass arrays where the elements can be of different types
+writeGenericList <- function(con, list) {
+ writeInt(con, length(list))
+ for (elem in list) {
+ writeObject(con, elem)
+ }
+}
+
# Used to pass in hash maps required on Java side.
writeEnv <- function(con, env) {
len <- length(env)
@@ -168,7 +176,7 @@ writeEnv <- function(con, env) {
if (len > 0) {
writeList(con, as.list(ls(env)))
vals <- lapply(ls(env), function(x) { env[[x]] })
- writeList(con, as.list(vals))
+ writeGenericList(con, as.list(vals))
}
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 68387f0f5365d..5ced7c688f98a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -225,14 +225,21 @@ sparkR.init <- function(
#' sqlContext <- sparkRSQL.init(sc)
#'}
-sparkRSQL.init <- function(jsc) {
+sparkRSQL.init <- function(jsc = NULL) {
if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
return(get(".sparkRSQLsc", envir = .sparkREnv))
}
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
- "createSQLContext",
- jsc)
+ "createSQLContext",
+ sc)
assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv)
sqlContext
}
@@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) {
#' sqlContext <- sparkRHive.init(sc)
#'}
-sparkRHive.init <- function(jsc) {
+sparkRHive.init <- function(jsc = NULL) {
if (exists(".sparkRHivesc", envir = .sparkREnv)) {
return(get(".sparkRHivesc", envir = .sparkREnv))
}
- ssc <- callJMethod(jsc, "sc")
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
+ ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
}, error = function(err) {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1857e636e8577..d2d82e791e876 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet")
writeLines(mockLines, jsonPath)
+# For testing NA functions such as dropna() and fillna()
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
+ "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}",
+ "{\"name\":\"David\",\"age\":60,\"height\":null}",
+ "{\"name\":\"Amy\",\"age\":null,\"height\":null}",
+ "{\"name\":null,\"age\":null,\"height\":null}")
+jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesNa, jsonPathNa)
+
test_that("infer types", {
expect_equal(infer_type(1L), "integer")
expect_equal(infer_type(1.0), "double")
@@ -765,5 +774,105 @@ test_that("describe() on a DataFrame", {
expect_equal(collect(stats)[5, "age"], "30")
})
+test_that("dropna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # drop with columns
+
+ expected <- rows[!is.na(rows$name),]
+ actual <- collect(dropna(df, cols = "name"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age),]
+ actual <- collect(dropna(df, cols = "age"))
+ row.names(expected) <- row.names(actual)
+ # identical() on the two data frames does not work here for unknown reasons;
+ # compare each column with identical() as a workaround.
+ expect_true(identical(expected$age, actual$age))
+ expect_true(identical(expected$height, actual$height))
+ expect_true(identical(expected$name, actual$name))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ # drop with how
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),]
+ actual <- collect(dropna(df, "all"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df, "any"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, "any", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height),]
+ actual <- collect(dropna(df, "all", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ # drop with threshold
+
+ expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,]
+ actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[as.integer(!is.na(rows$age)) +
+ as.integer(!is.na(rows$height)) +
+ as.integer(!is.na(rows$name)) >= 3,]
+ actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height")))
+ expect_true(identical(expected, actual))
+})
+
+test_that("fillna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # fill with value
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ actual <- collect(fillna(df, 50.6))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ actual <- collect(fillna(df, 50.6, "age"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown", c("age", "name")))
+ expect_true(identical(expected, actual))
+
+ # fill with named list
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")))
+ expect_true(identical(expected, actual))
+})
+
unlink(parquetPath)
unlink(jsonPath)
+unlink(jsonPathNa)
diff --git a/README.md b/README.md
index 9c09d40e2bdae..380422ca00dbe 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
Spark is a fast and general cluster computing system for Big Data. It provides
high-level APIs in Scala, Java, and Python, and an optimized engine that
supports general computation graphs for data analysis. It also supports a
-rich set of higher-level tools including Spark SQL for SQL and structured
-data processing, MLlib for machine learning, GraphX for graph processing,
+rich set of higher-level tools including Spark SQL for SQL and DataFrames,
+MLlib for machine learning, GraphX for graph processing,
and Spark Streaming for stream processing.
@@ -22,7 +22,7 @@ This README file only contains basic setup instructions.
Spark is built using [Apache Maven](http://maven.apache.org/).
To build Spark and its example programs, run:
- mvn -DskipTests clean package
+ build/mvn -DskipTests clean package
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
@@ -43,7 +43,7 @@ Try the following command, which should return 1000:
Alternatively, if you prefer Python, you can use the Python shell:
./bin/pyspark
-
+
And run the following command, which should also return 1000:
>>> sc.parallelize(range(1000)).count()
@@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example:
will run the Pi example locally.
You can set the MASTER environment variable when running examples to submit
-examples to a cluster. This can be a mesos:// or spark:// URL,
-"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
-locally with one thread, or "local[N]" to run locally with N threads. You
+examples to a cluster. This can be a mesos:// or spark:// URL,
+"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
+locally with one thread, or "local[N]" to run locally with N threads. You
can also use an abbreviated class name if the class is in the `examples`
package. For instance:
@@ -75,7 +75,7 @@ can be run using:
./dev/run-tests
-Please see the guidance on how to
+Please see the guidance on how to
[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools).
## A Note About Hadoop Versions
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 1f3dec91314f2..132cd433d78a2 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -40,6 +40,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.scalacheck</groupId>
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index ccb262a4ee02a..fb10d734ac74b 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.bagel
-import org.scalatest.{BeforeAndAfter, FunSuite, Assertions}
+import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
-class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
+class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
diff --git a/bin/pyspark b/bin/pyspark
index 8acad6113797d..7cb19c51b43a2 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
- if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
- else
- exec "$PYSPARK_DRIVER_PYTHON" $1
- fi
+ exec "$PYSPARK_DRIVER_PYTHON" -m $1
exit
fi
diff --git a/core/pom.xml b/core/pom.xml
index e58efe495e36d..a02184222e9f0 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -338,6 +338,12 @@
      <groupId>org.seleniumhq.selenium</groupId>
      <artifactId>selenium-java</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
      <scope>test</scope>
@@ -475,6 +481,29 @@
+    <profile>
+      <id>sparkr-docs</id>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.codehaus.mojo</groupId>
+            <artifactId>exec-maven-plugin</artifactId>
+            <executions>
+              <execution>
+                <id>sparkr-pkg-docs</id>
+                <phase>compile</phase>
+                <goals>
+                  <goal>exec</goal>
+                </goals>
+              </execution>
+            </executions>
+            <configuration>
+              <executable>..${path.separator}R${path.separator}create-docs${script.extension}</executable>
+            </configuration>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000000..d3d6280284beb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ *
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ *   <li>no Ordering is specified,</li>
+ *   <li>no Aggregator is specified, and</li>
+ *   <li>the number of partitions is less than
+ *     <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ *
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final Serializer serializer;
+
+ /** Array of file writers, one for each partition */
+ private BlockObjectWriter[] partitionWriters;
+
+ public BypassMergeSortShuffleWriter(
+ SparkConf conf,
+ BlockManager blockManager,
+ Partitioner partitioner,
+ ShuffleWriteMetrics writeMetrics,
+ Serializer serializer) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.numPartitions = partitioner.numPartitions();
+ this.blockManager = blockManager;
+ this.partitioner = partitioner;
+ this.writeMetrics = writeMetrics;
+ this.serializer = serializer;
+ }
+
+ @Override
+ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ if (!records.hasNext()) {
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new BlockObjectWriter[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (BlockObjectWriter writer : partitionWriters) {
+ writer.commitAndClose();
+ }
+ }
+
+ @Override
+ public long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException {
+ // Track location of the partition starts in the output file
+ final long[] lengths = new long[numPartitions];
+ if (partitionWriters == null) {
+ // We were passed an empty iterator
+ return lengths;
+ }
+
+ final FileOutputStream out = new FileOutputStream(outputFile, true);
+ final long writeStartTime = System.nanoTime();
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < numPartitions; i++) {
+ final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+ boolean copyThrewException = true;
+ try {
+ lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
+ }
+ threwException = false;
+ } finally {
+ Closeables.close(out, threwException);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ }
+ partitionWriters = null;
+ return lengths;
+ }
+
+ @Override
+ public void stop() throws IOException {
+ if (partitionWriters != null) {
+ try {
+ final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+ for (BlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ writer.revertPartialWritesAndClose();
+ if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+ logger.error("Error while deleting file for block {}", writer.blockId());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ }
+}
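To make the selection rule in the class comment concrete, here is a hedged Scala sketch of the predicate; this is illustrative only, not the actual `SortShuffleManager` code:

```scala
import org.apache.spark.{ShuffleDependency, SparkConf}

// Mirrors the three bullets in the class comment above (sketch only).
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
  val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
  dep.keyOrdering.isEmpty &&   // no Ordering is specified
    dep.aggregator.isEmpty &&  // no Aggregator is specified
    dep.partitioner.numPartitions < bypassMergeThreshold
}
```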
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000000000..656ea0401a144
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+ void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+ /**
+ * Write all the data added into this shuffle sorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
+ *
+ * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+ * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ */
+ long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException;
+
+ void stop() throws IOException;
+}
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index b8a5f5016860f..ceeb58075d345 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,8 +34,8 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- // When spilling is enabled sorting will happen externally, but not necessarily with an
- // ExternalSorter.
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 91f9ef8ce7185..48792a958130c 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def isCompleted: Boolean = jobWaiter.jobFinished
-
+
override def isCancelled: Boolean = _cancelled
override def value: Option[Try[T]] = {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index f2b024ff6cb67..6909015ff66e6 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils}
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
- * components to convey liveness or execution information for in-progress tasks. It will also
+ * components to convey liveness or execution information for in-progress tasks. It will also
* expire the hosts that have not heartbeated for more than spark.network.timeout.
*/
private[spark] case class Heartbeat(
@@ -43,8 +43,8 @@ private[spark] case class Heartbeat(
*/
private[spark] case object TaskSchedulerIsSet
-private[spark] case object ExpireDeadHosts
-
+private[spark] case object ExpireDeadHosts
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
@@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
// "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses
// "milliseconds"
- private val slaveTimeoutMs =
+ private val slaveTimeoutMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s")
- private val executorTimeoutMs =
+ private val executorTimeoutMs =
sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000
-
+
// "spark.network.timeoutInterval" uses "seconds", while
// "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds"
- private val timeoutIntervalMs =
+ private val timeoutIntervalMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s")
- private val checkTimeoutIntervalMs =
+ private val checkTimeoutIntervalMs =
sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000
-
+
private var timeoutCheckingTask: ScheduledFuture[_] = null
// "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not
@@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}
}
-
+
override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
timeoutCheckingTask.cancel(true)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 7e706bcc42f04..7cf7bc0dc6810 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -50,8 +50,8 @@ private[spark] class HttpFileServer(
def stop() {
httpServer.stop()
-
- // If we only stop sc, but the driver process still run as a services then we need to delete
+
+ // If we only stop sc but the driver process still runs as a service, we need to delete
// the tmp dir; otherwise it will create too many tmp dirs
try {
Utils.deleteRecursively(baseDir)
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 4b5bcb54aa873..46d72841dccce 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsBytes(key: String, defaultValue: String): Long = {
Utils.byteStringAsBytes(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Kibibytes are assumed.
@@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsKb(key: String, defaultValue: String): Long = {
Utils.byteStringAsKb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Mebibytes are assumed.
@@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsMb(key: String, defaultValue: String): Long = {
Utils.byteStringAsMb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Gibibytes are assumed.
@@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsGb(key: String, defaultValue: String): Long = {
Utils.byteStringAsGb(get(key, defaultValue))
}
-
+
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
Option(settings.get(key)).orElse(getDeprecatedConfig(key, this))
@@ -480,7 +480,7 @@ private[spark] object SparkConf extends Logging {
"spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " +
"are no longer accepted. To specify the equivalent now, one may use '64k'.")
)
-
+
Map(configs.map { cfg => (cfg.key -> cfg) } : _*)
}
@@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging {
"spark.reducer.maxSizeInFlight" -> Seq(
AlternateConfig("spark.reducer.maxMbInFlight", "1.4")),
"spark.kryoserializer.buffer" ->
- Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
+ Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
translation = s => s"${(s.toDouble * 1000).toInt}k")),
"spark.kryoserializer.buffer.max" -> Seq(
AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")),
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index fe6320b504e15..a1ebbecf93b7b 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -51,7 +51,7 @@ private[spark] object TestUtils {
classpathUrls: Seq[URL] = Seq()): URL = {
val tempDir = Utils.createTempDir()
val files1 = for (name <- classNames) yield {
- createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
+ createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
}
val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 61af867b11b9c..a650df605b92e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double])
*/
def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index db4e996feb31c..ed312770ee131 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
@@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index b8e15f38a20d2..c95615a5a9307 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@deprecated("Use partitions() instead.", "1.1.0")
def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
-
+
/** Set of partitions in this RDD. */
def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
+ /** The partitioner of this RDD. */
+ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner)
+
/** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
@@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
- def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+ def takeSample(withReplacement: Boolean, num: Int): JList[T] =
takeSample(withReplacement, num, Utils.random.nextLong)
-
+
def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a77bf42ce1d38..55a37f8c944b2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -797,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
- /**
+ /**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded actor anyway.
- */
+ */
@transient var socket: Socket = _
def openSocket(): Socket = synchronized {
@@ -843,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
* A wrapper for a Python Broadcast, which is written to disk by Python. It also
* writes the data to disk after deserialization, so Python can read it from disk.
*/
+// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
/**
@@ -884,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
}
}
+// scalastyle:on no.finalize
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 0a91977928cee..d24c650d37bb0 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -44,11 +44,11 @@ private[spark] class RBackend {
bossGroup = new NioEventLoopGroup(2)
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
-
+
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(classOf[NioServerSocketChannel])
-
+
bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
def initChannel(ch: SocketChannel): Unit = {
ch.pipeline()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 026a1b9380357..2e86984c66b3a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend)
val reply = bos.toByteArray
ctx.write(reply)
}
-
+
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
ctx.flush()
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index e020458888e4a..4dfa7325934ff 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -355,7 +355,6 @@ private[r] object RRDD {
val sparkConf = new SparkConf().setAppName(appName)
.setSparkHome(sparkHome)
- .setJars(jars)
// Override `master` if we have a user-specified value
if (master != "") {
@@ -373,7 +372,11 @@ private[r] object RRDD {
sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String])
}
- new JavaSparkContext(sparkConf)
+ val jsc = new JavaSparkContext(sparkConf)
+ jars.foreach { jar =>
+ jsc.addJar(jar)
+ }
+ jsc
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 371dfe454d1a2..f8e3f1a79082e 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -157,9 +157,11 @@ private[spark] object SerDe {
val keysLen = readInt(in)
val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
- val valuesType = readObjectType(in)
val valuesLen = readInt(in)
- val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType))
+ val values = (0 until valuesLen).map(_ => {
+ val valueType = readObjectType(in)
+ readTypedObject(in, valueType)
+ })
mapAsJavaMap(keys.zip(values).toMap)
} else {
new java.util.HashMap[Object, Object]()
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 92bb5059a0313..8cf4d58847d8e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -428,6 +428,8 @@ object SparkSubmit {
OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"),
OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
+ OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"),
+ OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"),
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
@@ -440,10 +442,8 @@ object SparkSubmit {
OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
-
- // Yarn client or cluster
- OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"),
- OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"),
+ OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"),
+ OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),
// Other options
OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES,
@@ -869,7 +869,7 @@ private[spark] object SparkSubmitUtils {
md.addDependency(dd)
}
}
-
+
/** Add exclusion rules for dependencies already included in the spark-assembly */
def addExclusionRules(
ivySettings: IvySettings,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index c0e4c771908b3..cc6a7bd9f4119 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -169,6 +169,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
numExecutors = Option(numExecutors)
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
+ keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
+ principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index a2a97a7877ce7..4692d22651c93 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -23,7 +23,7 @@ import org.apache.spark.util.Utils
/**
* Command-line parser for the history server.
*/
-private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
+private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
extends Logging {
private var propertiesFile: String = null
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 80db6d474b5c1..328d95a7a0c68 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil
private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
extends PersistenceEngine
with Logging {
-
+
private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 756927682cd24..6a7c74020bace 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
+ val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
@@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}.getOrElse { Seq.empty }
}
-          <li><strong>Workers:</strong> {state.workers.size}</li>
-          <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
-            {state.workers.map(_.coresUsed).sum} Used</li>
-          <li><strong>Memory:</strong>
-            {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-            {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
+          <li><strong>Alive Workers:</strong> {aliveWorkers.size}</li>
+          <li><strong>Cores in use:</strong> {aliveWorkers.map(_.cores).sum} Total,
+            {aliveWorkers.map(_.coresUsed).sum} Used</li>
+          <li><strong>Memory in use:</strong>
+            {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
+            {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used</li>
diff --git a/docs/building-spark.md b/docs/building-spark.md
index b2649d1ee2a53..78cb9086f95e8 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -80,6 +80,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro
    <tr><td>2.2.x</td><td>hadoop-2.2</td></tr>
    <tr><td>2.3.x</td><td>hadoop-2.3</td></tr>
    <tr><td>2.4.x</td><td>hadoop-2.4</td></tr>
+    <tr><td>2.6.x and later 2.x</td><td>hadoop-2.6</td></tr>
@@ -130,9 +131,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop
dev/change-version-to-2.11.sh
mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
-Scala 2.11 support in Spark does not support a few features due to dependencies
-which are themselves not Scala 2.11 ready. Specifically, Spark's external
-Kafka library and JDBC component are not yet supported in Scala 2.11 builds.
+Spark does not yet support its JDBC component for Scala 2.11.
# Spark Tests in Maven
diff --git a/docs/configuration.md b/docs/configuration.md
index 30508a617fdd8..3a48da4592dd9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1,4 +1,4 @@
---
+---
layout: global
displayTitle: Spark Configuration
title: Configuration
diff --git a/docs/index.md b/docs/index.md
index 5ef6d983c45a5..fac071da81e60 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -54,7 +54,7 @@ Example applications are also provided in Python. For example,
./bin/spark-submit examples/src/main/python/pi.py 10
-Spark also provides an experimental R API since 1.4 (only DataFrames APIs included).
+Since 1.4, Spark also provides an experimental [R API](sparkr.html) (only DataFrames APIs included).
To run Spark interactively in a R interpreter, use `bin/sparkR`:
./bin/sparkR --master local[2]
diff --git a/docs/ml-features.md b/docs/ml-features.md
index d7851a55fabfe..f88c0248c1a8a 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
+## StringIndexer
+
+`StringIndexer` encodes a string column of labels to a column of label indices.
+The indices are in `[0, numLabels)`, ordered by label frequencies.
+So the most frequent label gets index `0`.
+If the input column is numeric, we cast it to string and index the string values.
+
+**Examples**
+
+Assume that we have the following DataFrame with columns `id` and `category`:
+
+~~~~
+ id | category
+----|----------
+ 0 | a
+ 1 | b
+ 2 | c
+ 3 | a
+ 4 | a
+ 5 | c
+~~~~
+
+`category` is a string column with three labels: "a", "b", and "c".
+Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
+column, we should get the following:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0 | a | 0.0
+ 1 | b | 2.0
+ 2 | c | 1.0
+ 3 | a | 0.0
+ 4 | a | 0.0
+ 5 | c | 1.0
+~~~~
+
+"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
+index `2`.
+
+
+
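The code-tab block for this section appears to have been stripped from this excerpt; a minimal Scala sketch of the usage described above, assuming an existing `sqlContext`:

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

// fit() computes label frequencies; transform() appends categoryIndex.
val indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}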
## OneHotEncoder
+[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
@@ -876,5 +992,207 @@ bucketedData = bucketizer.transform(dataFrame)
+## ElementwiseProduct
+
+ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector `v` and the transforming vector `w`, yielding a result vector.
+
+`\[ \begin{pmatrix}
+v_1 \\
+\vdots \\
+v_N
+\end{pmatrix} \circ \begin{pmatrix}
+ w_1 \\
+ \vdots \\
+ w_N
+ \end{pmatrix}
+= \begin{pmatrix}
+ v_1 w_1 \\
+ \vdots \\
+ v_N w_N
+ \end{pmatrix}
+\]`
+
+[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter:
+
+* `scalingVec`: the transforming vector.
+
+The example below demonstrates how to transform vectors using a transforming vector value.
+
+
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.ElementwiseProduct
+import org.apache.spark.mllib.linalg.Vectors
+
+// Create some vector data; also works for sparse vectors
+val dataFrame = sqlContext.createDataFrame(Seq(
+ ("a", Vectors.dense(1.0, 2.0, 3.0)),
+ ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector")
+
+val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
+val transformer = new ElementwiseProduct()
+ .setScalingVec(transformingVector)
+ .setInputCol("vector")
+ .setOutputCol("transformedVector")
+
+// Batch transform the vectors to create new column:
+val transformedData = transformer.transform(dataFrame)
+
+{% endhighlight %}
+
+
+## VectorAssembler
+
+`VectorAssembler` is a transformer that combines a given list of columns into a single vector
+column.
+It is useful for combining raw features and features generated by different feature transformers
+into a single feature vector, in order to train ML models like logistic regression and decision
+trees.
+`VectorAssembler` accepts the following input column types: all numeric types, boolean type,
+and vector type.
+In each row, the values of the input columns will be concatenated into a vector in the specified
+order.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`,
+and `clicked`:
+
+~~~
+ id | hour | mobile | userFeatures | clicked
+----|------|--------|------------------|---------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0
+~~~
+
+`userFeatures` is a vector column that contains three user features.
+We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector
+called `features` and use it to predict whether or not the user clicked.
+If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and
+output column to `features`, after transformation we should get the following DataFrame:
+
+~~~
+ id | hour | mobile | userFeatures | clicked | features
+----|------|--------|------------------|---------|-----------------------------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5]
+~~~
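+
+A minimal Scala sketch of this assembly (assuming an existing `sqlContext`; the single-row
+dataset mirrors the table above):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.mllib.linalg.Vectors
+
+val dataset = sqlContext.createDataFrame(
+  Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0))
+).toDF("id", "hour", "mobile", "userFeatures", "clicked")
+val assembler = new VectorAssembler()
+  .setInputCols(Array("hour", "mobile", "userFeatures"))
+  .setOutputCol("features")
+val output = assembler.transform(dataset)
+{% endhighlight %}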
+
+
+
# Feature Selectors
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index c5f50ed7990f1..4eb622d4b95e8 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF)
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+println("Model 1 was fit using parameters: " + model1.parent.extractParamMap)
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF, paramMapCombined)
-println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+println("Model 2 was fit using parameters: " + model2.parent.extractParamMap)
// Prepare test data.
val test = sc.parallelize(Seq(
@@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training);
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
-System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List localTest = Lists.newArrayList(
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index f41ca70952eb7..dac22f736e8cb 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin
optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight scala %}
-import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations)
// Evaluate clustering by computing Within Set Sum of Squared Errors
val WSSSE = clusters.computeCost(parsedData)
println("Within Set Sum of Squared Errors = " + WSSSE)
+
+// Save and load model
+clusters.save(sc, "myModelPath")
+val sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -110,6 +114,10 @@ public class KMeansExample {
// Evaluate clustering by computing Within Set Sum of Squared Errors
double WSSSE = clusters.computeCost(parsedData.rdd());
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+ // Save and load model
+ clusters.save(sc.sc(), "myModelPath");
+ KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in
fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight python %}
-from pyspark.mllib.clustering import KMeans
+from pyspark.mllib.clustering import KMeans, KMeansModel
from numpy import array
from math import sqrt
@@ -143,6 +151,10 @@ def error(point):
WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y)
print("Within Set Sum of Squared Error = " + str(WSSSE))
+
+# Save and load model
+clusters.save(sc, "myModelPath")
+sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -312,12 +324,12 @@ Calling `PowerIterationClustering.run` returns a
which contains the computed clustering assignments.
{% highlight scala %}
-import org.apache.spark.mllib.clustering.PowerIterationClustering
+import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel}
import org.apache.spark.mllib.linalg.Vectors
val similarities: RDD[(Long, Long, Double)] = ...
-val pic = new PowerIteartionClustering()
+val pic = new PowerIterationClustering()
.setK(3)
.setMaxIterations(20)
val model = pic.run(similarities)
@@ -325,6 +337,10 @@ val model = pic.run(similarities)
model.assignments.foreach { a =>
println(s"${a.id} -> ${a.cluster}")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath")
{% endhighlight %}
A full example that produces the experiment described in the PIC paper can be found under
@@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities);
for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
System.out.println(a.id() + " -> " + a.cluster());
}
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index f723cd6b9dfab..4fe470a8de810 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th
import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
val input = sc.textFile("text8").map(line => line.split(" ").toSeq)
@@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40)
for((synonym, cosineSimilarity) <- synonyms) {
println(s"$synonym $cosineSimilarity")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = Word2VecModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.feature.ChiSqSelector
// Load some data in libsvm format
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
@@ -505,7 +510,7 @@ v_N
### Example
-This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value.
+The example below demonstrates how to transform vectors using a transforming vector value.
@@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors
-// Load and parse the data:
-val data = sc.textFile("data/mllib/kmeans_data.txt")
-val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
+// Create some vector data; also works for sparse vectors
+val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)))
val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
val transformer = new ElementwiseProduct(transformingVector)
// Batch transform and per-row transform give the same results:
-val transformedData = transformer.transform(parsedData)
-val transformedData2 = parsedData.map(x => transformer.transform(x))
+val transformedData = transformer.transform(data)
+val transformedData2 = data.map(x => transformer.transform(x))
+
+{% endhighlight %}
+
+
+
+{% highlight java %}
+import java.util.Arrays;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.feature.ElementwiseProduct;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+// Create some vector data; also works for sparse vectors
+JavaRDD<Vector> data = sc.parallelize(Arrays.asList(
+  Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)));
+Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
+final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector);
+
+// Batch transform and per-row transform give the same results:
+JavaRDD<Vector> transformedData = transformer.transform(data);
+JavaRDD<Vector> transformedData2 = data.map(
+  new Function<Vector, Vector>() {
+    @Override
+    public Vector call(Vector v) {
+      return transformer.transform(v);
+    }
+  }
+);
{% endhighlight %}
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index b521c2f27cd6e..5732bc4c7e79e 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f
labels and real labels in the test set.
{% highlight scala %}
-import org.apache.spark.mllib.regression.IsotonicRegression
+import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel}
val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")
@@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point =>
// Calculate mean squared error between predicted and real labels.
val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean()
println("Mean Squared Error = " + meanSquaredError)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = IsotonicRegressionModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map(
).rdd()).mean();
System.out.println("Mean Squared Error = " + meanSquaredError);
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 9d55f435e80ad..96cf612c54fdd 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
running against earlier versions, this property will be ignored.
+<tr>
+  <td><code>spark.yarn.keytab</code></td>
+  <td>(none)</td>
+  <td>
+  The full path to the file that contains the keytab for the principal specified above.
+  This keytab will be copied to the node running the Application Master via the Secure Distributed Cache,
+  for renewing the login tickets and the delegation tokens periodically.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.principal</code></td>
+  <td>(none)</td>
+  <td>
+  Principal to be used to login to KDC, while running on secure HDFS.
+  </td>
+</tr>
# Launching Spark on YARN
diff --git a/docs/sparkr.md b/docs/sparkr.md
new file mode 100644
index 0000000000000..4d82129921a37
--- /dev/null
+++ b/docs/sparkr.md
@@ -0,0 +1,223 @@
+---
+layout: global
+displayTitle: SparkR (R on Spark)
+title: SparkR (R on Spark)
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+SparkR is an R package that provides a light-weight frontend to use Apache Spark from R.
+In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that
+supports operations like selection, filtering, and aggregation (similar to R data frames and
+[dplyr](https://github.com/hadley/dplyr)), but on large datasets.
+
+# SparkR DataFrames
+
+A DataFrame is a distributed collection of data organized into named columns. It is conceptually
+equivalent to a table in a relational database or a data frame in R, but with richer
+optimizations under the hood. DataFrames can be constructed from a wide array of sources such as:
+structured data files, tables in Hive, external databases, or existing local R data frames.
+
+All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell.
+
+## Starting Up: SparkContext, SQLContext
+
+
+The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
+You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name
+etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the
+SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should
+already be created for you.
+
+{% highlight r %}
+sc <- sparkR.init()
+sqlContext <- sparkRSQL.init(sc)
+{% endhighlight %}
+
+
+
+## Creating DataFrames
+With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
+
+### From local data frames
+The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based on the `faithful` dataset from R.
+
+
+{% highlight r %}
+df <- createDataFrame(sqlContext, faithful)
+
+# Displays the content of the DataFrame to stdout
+head(df)
+## eruptions waiting
+##1 3.600 79
+##2 1.800 54
+##3 3.333 74
+
+{% endhighlight %}
+
+
+### From Data Sources
+
+SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
+
+The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro).
+
+We can see how to use data sources with an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
+
+
+
+{% highlight r %}
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+head(people)
+## age name
+##1 NA Michael
+##2 30 Andy
+##3 19 Justin
+
+# SparkR automatically infers the schema from the JSON file
+printSchema(people)
+# root
+# |-- age: integer (nullable = true)
+# |-- name: string (nullable = true)
+
+{% endhighlight %}
+
+
+The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example
+to a Parquet file using `write.df`, as shown below.
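+
+For example (a minimal sketch; the output path is illustrative):
+
+{% highlight r %}
+write.df(people, path="people.parquet", source="parquet", mode="overwrite")
+{% endhighlight %}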
+
+
+
+### From Hive tables
+
+You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext).
+
+
+{% highlight r %}
+# sc is an existing SparkContext.
+hiveContext <- sparkRHive.init(sc)
+
+sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+
+# Queries can be expressed in HiveQL.
+results <- sql(hiveContext, "FROM src SELECT key, value")
+
+# results is now a DataFrame
+head(results)
+## key value
+## 1 238 val_238
+## 2 86 val_86
+## 3 311 val_311
+
+{% endhighlight %}
+
+
+## DataFrame Operations
+
+SparkR DataFrames support a number of functions to do structured data processing.
+Here we include some basic examples; a complete list can be found in the [API](api/R/index.html) docs:
+
+### Selecting rows, columns
+
+
+{% highlight r %}
+# Create the DataFrame
+df <- createDataFrame(sqlContext, faithful)
+
+# Get basic information about the DataFrame
+df
+## DataFrame[eruptions:double, waiting:double]
+
+# Select only the "eruptions" column
+head(select(df, df$eruptions))
+## eruptions
+##1 3.600
+##2 1.800
+##3 3.333
+
+# You can also pass in column name as strings
+head(select(df, "eruptions"))
+
+# Filter the DataFrame to only retain rows with wait times shorter than 50 mins
+head(filter(df, df$waiting < 50))
+## eruptions waiting
+##1 1.750 47
+##2 1.750 47
+##3 1.867 48
+
+{% endhighlight %}
+
+
+
+### Grouping, Aggregation
+
+SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below.
+
+
+{% highlight r %}
+
+# We use the `n` operator to count the number of times each waiting time appears
+head(summarize(groupBy(df, df$waiting), count = n(df$waiting)))
+## waiting count
+##1 81 13
+##2 60 6
+##3 68 1
+
+# We can also sort the output from the aggregation to get the most common waiting times
+waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting))
+head(arrange(waiting_counts, desc(waiting_counts$count)))
+
+## waiting count
+##1 78 15
+##2 83 14
+##3 81 13
+
+{% endhighlight %}
+
+
+### Operating on Columns
+
+SparkR also provides a number of functions that can be directly applied to columns for data processing and aggregation. The example below shows the use of basic arithmetic functions.
+
+
+{% highlight r %}
+
+# Convert waiting time from minutes to seconds.
+# Note that we can assign this to a new column in the same DataFrame
+df$waiting_secs <- df$waiting * 60
+head(df)
+## eruptions waiting waiting_secs
+##1 3.600 79 4740
+##2 1.800 54 3240
+##3 3.333 74 4440
+
+{% endhighlight %}
+
+
+## Running SQL Queries from SparkR
+A SparkR DataFrame can also be registered as a temporary table in Spark SQL; registering a DataFrame as a table allows you to run SQL queries over its data.
+The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`.
+
+
+{% highlight r %}
+# Load a JSON file
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+
+# Register this DataFrame as a table.
+registerTempTable(people, "people")
+
+# SQL statements can be run by using the sql method
+teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19")
+head(teenagers)
+## name
+##1 Justin
+
+{% endhighlight %}
+
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ab646f65bb5eb..282ea75e1e785 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -11,6 +11,7 @@ title: Spark SQL and DataFrames
Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as a distributed SQL query engine.
+For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section.
# DataFrames
@@ -108,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO
val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Displays the content of the DataFrame to stdout
df.show()
@@ -121,7 +122,7 @@ df.show()
JavaSparkContext sc = ...; // An existing JavaSparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Displays the content of the DataFrame to stdout
df.show();
@@ -134,7 +135,7 @@ df.show();
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Displays the content of the DataFrame to stdout
df.show()
@@ -170,7 +171,7 @@ val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Show the content of the DataFrame
df.show()
@@ -220,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Show the content of the DataFrame
df.show();
@@ -276,7 +277,7 @@ from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
# Create the DataFrame
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Show the content of the DataFrame
df.show()
@@ -776,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
Ignore mode means that when saving a DataFrame to a data source, if data already exists,
the save operation is expected to not save the contents of the DataFrame and to not
- change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL.
+ change the existing data. This is similar to a <code>CREATE TABLE IF NOT EXISTS</code> in SQL.
@@ -946,11 +947,11 @@ import sqlContext.implicits._
val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet.
-people.saveAsParquetFile("people.parquet")
+people.write.parquet("people.parquet")
// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a Parquet file is also a DataFrame.
-val parquetFile = sqlContext.parquetFile("people.parquet")
+val parquetFile = sqlContext.read.parquet("people.parquet")
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile")
@@ -968,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
DataFrame schemaPeople = ... // The DataFrame from the previous example.
// DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet");
+schemaPeople.write().parquet("people.parquet");
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a parquet file is also a DataFrame.
-DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
+DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -994,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
schemaPeople # The DataFrame from the previous example.
# DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet")
+schemaPeople.write.parquet("people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
# The result of loading a parquet file is also a DataFrame.
-parquetFile = sqlContext.parquetFile("people.parquet")
+parquetFile = sqlContext.read.parquet("people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a
teenNames <- map(teenagers, function(p) { paste("Name:", p$name)})
for (teenName in collect(teenNames)) {
cat(teenName, "\n")
-}
+}
{% endhighlight %}
@@ -1086,9 +1087,9 @@ path
{% endhighlight %}
-By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will
-automatically extract the partitioning information from the paths. Now the schema of the returned
-DataFrame becomes:
+By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL
+will automatically extract the partitioning information from the paths.
+Now the schema of the returned DataFrame becomes:
{% highlight text %}
@@ -1121,15 +1122,15 @@ import sqlContext.implicits._
// Create a simple DataFrame, stored into a partition directory
val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double")
-df1.saveAsParquetFile("data/test_table/key=1")
+df1.write.parquet("data/test_table/key=1")
// Create another DataFrame in a new partition directory,
// adding a new column and dropping an existing column
val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple")
-df2.saveAsParquetFile("data/test_table/key=2")
+df2.write.parquet("data/test_table/key=2")
// Read the partitioned table
-val df3 = sqlContext.parquetFile("data/test_table")
+val df3 = sqlContext.read.parquet("data/test_table")
df3.printSchema()
// The final schema consists of all 3 columns in the Parquet files together
@@ -1268,12 +1269,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read.json()` on either an RDD of String,
+or a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a JSON file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1284,8 +1283,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
val path = "examples/src/main/resources/people.json"
-// Create a DataFrame from the file(s) pointed to by path
-val people = sqlContext.jsonFile(path)
+val people = sqlContext.read.json(path)
// The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1303,19 +1301,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age
// an RDD[String] storing one JSON object per string.
val anotherPeopleRDD = sc.parallelize(
"""{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
-val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
+val anotherPeople = sqlContext.read.json(anotherPeopleRDD)
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext` :
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read().json()` on either an RDD of String,
+or a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a JSON file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1325,9 +1321,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
-String path = "examples/src/main/resources/people.json";
-// Create a DataFrame from the file(s) pointed to by path
-DataFrame people = sqlContext.jsonFile(path);
+DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json");
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
@@ -1346,18 +1340,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN
List jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD anotherPeopleRDD = sc.parallelize(jsonData);
-DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD);
+DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD);
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
+This conversion can be done using `SQLContext.read.json` on a JSON file.
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
-
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a JSON file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1368,9 +1359,7 @@ sqlContext = SQLContext(sc)
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
-path = "examples/src/main/resources/people.json"
-# Create a DataFrame from the file(s) pointed to by path
-people = sqlContext.jsonFile(path)
+people = sqlContext.read.json("examples/src/main/resources/people.json")
# The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1393,12 +1382,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame using
+the `jsonFile` function, which loads data from a directory of JSON files where each line of the
+files is a JSON object.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a JSON file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1502,7 +1490,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
-adds support for finding tables in the MetaStore and writing queries using HiveQL.
+adds support for finding tables in the MetaStore and writing queries using HiveQL.
{% highlight python %}
# sc is an existing SparkContext.
from pyspark.sql import HiveContext
@@ -1526,8 +1514,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ
# sc is an existing SparkContext.
sqlContext <- sparkRHive.init(sc)
-hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
# Queries can be expressed in HiveQL.
results = sqlContext.sql("FROM src SELECT key, value").collect()
@@ -1537,6 +1525,70 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
+### Interacting with Different Versions of Hive Metastore
+
+One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore,
+which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below.
+
+Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET`
+and `DESCRIBE`, and the other dedicated to communicating with the Hive metastore. The former uses
+Hive jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the
+version specified by users. An isolated classloader is used here to avoid dependency conflicts.
+
+<table class="table">
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.version</code></td>
+    <td><code>0.13.1</code></td>
+    <td>
+      Version of the Hive metastore. Available
+      options are <code>0.12.0</code> and <code>0.13.1</code>. Support for more versions is coming in the future.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.jars</code></td>
+    <td><code>builtin</code></td>
+    <td>
+      Location of the jars that should be used to instantiate the HiveMetastoreClient. This
+      property can be one of three options:
+      <ol>
+        <li><code>builtin</code><br/>
+          Use Hive 0.13.1, which is bundled with the Spark assembly jar when <code>-Phive</code> is
+          enabled. When this option is chosen, <code>spark.sql.hive.metastore.version</code> must be
+          either <code>0.13.1</code> or not defined.</li>
+        <li><code>maven</code><br/>
+          Use Hive jars of specified version downloaded from Maven repositories.</li>
+        <li>A classpath in the standard format for both Hive and Hadoop.</li>
+      </ol>
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.sharedPrefixes</code></td>
+    <td><code>com.mysql.jdbc,<br/>org.postgresql,<br/>com.microsoft.sqlserver,<br/>oracle.jdbc</code></td>
+    <td>
+      A comma separated list of class prefixes that should be loaded using the classloader that is
+      shared between Spark SQL and a specific version of Hive. An example of classes that should
+      be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need
+      to be shared are those that interact with classes that are already shared. For example,
+      custom appenders that are used by log4j.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.barrierPrefixes</code></td>
+    <td><code>(empty)</code></td>
+    <td>
+      A comma separated list of class prefixes that should explicitly be reloaded for each version
+      of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a
+      prefix that typically would be shared (i.e. <code>org.apache.spark.*</code>).
+    </td>
+  </tr>
+</table>
+
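+As a sketch, these properties are typically set on the `SparkConf` used to create the
+contexts; the specific values below are illustrative:
+
+{% highlight scala %}
+import org.apache.spark.SparkConf
+
+// Illustrative only: query a Hive 0.12.0 metastore with jars resolved from Maven
+val conf = new SparkConf()
+  .set("spark.sql.hive.metastore.version", "0.12.0")
+  .set("spark.sql.hive.metastore.jars", "maven")
+{% endhighlight %}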
+
## JDBC To Other Databases
Spark SQL also includes a data source that can read data from other databases using JDBC. This
@@ -1570,7 +1622,7 @@ the Data Sources API. The following options are supported:
dbtable
- The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of
+ The JDBC table that should be read. Note that anything that is valid in a <code>FROM</code> clause of
a SQL query can be used. For example, instead of a full table you could also use a
subquery in parentheses.
@@ -1714,7 +1766,7 @@ that these options will be deprecated in future release as more optimizations ar
Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
statistics are only supported for Hive Metastore tables where the command
- `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run.
+ <code>ANALYZE TABLE &lt;tableName&gt; COMPUTE STATISTICS noscan</code> has been run.
@@ -1737,7 +1789,9 @@ that these options will be deprecated in future release as more optimizations ar
# Distributed SQL Engine
-Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code.
+Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface.
+In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries,
+without the need to write any code.
## Running the Thrift JDBC/ODBC server
@@ -1816,6 +1870,25 @@ options.
## Upgrading from Spark SQL 1.3 to 1.4
+#### DataFrame data reader/writer interface
+
+Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`)
+and writing data out (`DataFrame.write`),
+and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`).
+
+See the API docs for `SQLContext.read` (
+ Scala,
+ Java,
+ Python
+) and `DataFrame.write` (
+ Scala,
+ Java,
+ Python
+) for more information.
+
+
+#### DataFrame.groupBy retains grouping columns
+
Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
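
For example (a minimal Scala sketch, assuming a DataFrame `df` with `department`, `age`, and `expense` columns):

{% highlight scala %}
import org.apache.spark.sql.functions._

// In 1.4 the grouping column "department" is retained in the result;
// in 1.3 the result contained only max(age) and sum(expense).
df.groupBy("department").agg(max("age"), sum("expense"))
{% endhighlight %}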
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index bd863d48d53e3..42b33947873b0 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s
Receiving multiple data streams can therefore be achieved by creating multiple input DStreams
and configuring them to receive different partitions of the data stream from the source(s).
For example, a single Kafka input DStream receiving two topics of data can be split into two
-Kafka input streams, each receiving only one topic. This would run two receivers on two workers,
-thus allowing data to be received in parallel, and increasing overall throughput. These multiple
-DStream can be unioned together to create a single DStream. Then the transformations that was
-being applied on the single input DStream can applied on the unified stream. This is done as follows.
+Kafka input streams, each receiving only one topic. This would run two receivers,
+allowing data to be received in parallel, and increasing overall throughput. These multiple
+DStreams can be unioned together to create a single DStream. Then the transformations that were
+being applied on a single input DStream can be applied on the unified stream. This is done as follows.
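+
+For example (a minimal Scala sketch, assuming an existing `StreamingContext` named `ssc` and
+illustrative Kafka settings):
+
+{% highlight scala %}
+import org.apache.spark.streaming.kafka.KafkaUtils
+
+// Two receivers, each consuming one topic in the same consumer group
+val kafkaStreams = Seq("topic1", "topic2").map { topic =>
+  KafkaUtils.createStream(ssc, "localhost:2181", "my-group", Map(topic -> 1))
+}
+// Union into a single DStream and apply transformations to the unified stream
+val unifiedStream = ssc.union(kafkaStreams)
+unifiedStream.print()
+{% endhighlight %}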
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 29158d5c85651..dac649d1d5ae6 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -97,7 +97,7 @@ public static void main(String[] args) {
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
DataFrame results = model2.transform(test);
diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py
new file mode 100644
index 0000000000000..6446f0fe5eeab
--- /dev/null
+++ b/examples/src/main/python/ml/gradient_boosted_trees.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import GBTClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import GBTRegressor
+from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline.
+Note: GBTClassifier only supports binary classification currently
+Run with:
+ bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py
+"""
+
+
+def testClassification(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = BinaryClassificationMetrics(predictionAndLabels)
+ print("AUC %.3f" % metrics.areaUnderROC)
+
+
+def testRegression(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: gradient_boosted_trees", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonGBTExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py
new file mode 100644
index 0000000000000..c7730e1bfacd9
--- /dev/null
+++ b/examples/src/main/python/ml/random_forest_example.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import RandomForestRegressor
+from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a RandomForest Classification/Regression Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/random_forest_example.py
+"""
+
+
+def testClassification(train, test):
+ # Train a RandomForest model.
+ # Setting featureSubsetStrategy="auto" lets the algorithm choose.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+
+def testRegression(train, test):
+ # Train a RandomForest model.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: random_forest_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonRandomForestExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
new file mode 100644
index 0000000000000..3933d59b52cd1
--- /dev/null
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import pprint
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.linalg import DenseVector
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/simple_params_example.py
+"""
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: simple_params_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonSimpleParamsExample")
+ sqlContext = SQLContext(sc)
+
+ # prepare training data.
+ # We create an RDD of LabeledPoints and convert them into a DataFrame.
+ # Spark DataFrames can automatically infer the schema from named tuples
+ # and LabeledPoint implements __reduce__ to behave like a named tuple.
+ training = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
+ LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+
+ # Create a LogisticRegression instance with maxIter = 10.
+ # This instance is an Estimator.
+ lr = LogisticRegression(maxIter=10)
+ # Print out the parameters, documentation, and any default values.
+ print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ # We may also set parameters using setter methods.
+ lr.setRegParam(0.01)
+
+ # Learn a LogisticRegression model. This uses the parameters stored in lr.
+ model1 = lr.fit(training)
+
+ # Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ # we can view the parameters it used during fit().
+ # This prints the parameter (name: value) pairs, where names are unique IDs for this
+ # LogisticRegression instance.
+ print("Model 1 was fit using parameters:\n")
+ pprint.pprint(model1.extractParamMap())
+
+ # We may alternatively specify parameters using a parameter map.
+ # paramMap overrides all lr parameters set earlier.
+ paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+
+ # Now learn a new model using the new parameters.
+ model2 = lr.fit(training, paramMap)
+ print("Model 2 was fit using parameters:\n")
+ pprint.pprint(model2.extractParamMap())
+
+ # prepare test data.
+ test = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
+ LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
+ LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+
+ # Make predictions on test data using the Transformer.transform() method.
+ # LogisticRegressionModel.transform will only use the 'features' column.
+ # Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ # 'probability' column since we renamed the lr.probabilityCol parameter previously.
+ result = model2.transform(test) \
+ .select("features", "label", "myProbability", "prediction") \
+ .collect()
+
+ for row in result:
+ print("features=%s,label=%s -> prob=%s, prediction=%s"
+ % (row.features, row.label, row.myProbability, row.prediction))
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index 32e02eab8b031..75c82117cbad2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._
/**
* Executes a roll up-style query against Apache logs.
- *
+ *
* Usage: LogQuery [logFile]
*/
object LogQuery {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a991f50e338..a0561e2573fc9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -87,7 +87,7 @@ object SimpleParamsExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test.toDF())
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b0613632c9946..3381941673db8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,7 +22,6 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -354,7 +353,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
+ *
+ * This is just for demo purpose. In general, don't copy this code because it is NOT efficient
+ * due to the use of structural types, which leads to one reflection call per record.
*/
+ // scalastyle:off structural.type
private[mllib] def meanSquaredError(
model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
@@ -363,4 +366,5 @@ object DecisionTreeRunner {
err * err
}.mean()
}
+ // scalastyle:on structural.type
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index 9a1aab036aa0f..f8c71ccabc43b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -41,22 +41,22 @@ object DenseGaussianMixture {
private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
val ctx = new SparkContext(conf)
-
+
val data = ctx.textFile(inputFile).map { line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache()
-
+
val clusters = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
.setMaxIterations(maxIterations)
.run(data)
-
+
for (i <- 0 until clusters.k) {
- println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
-
+
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
index b336751d81616..813c8554f5193 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
@@ -40,7 +40,7 @@ object MQTTPublisher {
StreamingExamples.setStreamingLogLevels()
val Seq(brokerUrl, topic) = args.toSeq
-
+
var client: MqttClient = null
try {
@@ -59,10 +59,10 @@ object MQTTPublisher {
println(s"Published data. topic: ${msgtopic.getName()}; Message: $message")
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- Thread.sleep(10)
+ Thread.sleep(10)
println("Queue is full, wait for to consume data from the message queue")
- }
- }
+ }
+ }
} catch {
case e: MqttException => println("Exception Caught: " + e)
} finally {
@@ -107,7 +107,7 @@ object MQTTWordCount {
val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2)
val words = lines.flatMap(x => x.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
-
+
wordCounts.print()
ssc.start()
ssc.awaitTermination()
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 1f3e619d97a24..71f2b6fe18bd1 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -42,15 +42,46 @@
     <dependency>
       <groupId>org.apache.flume</groupId>
       <artifactId>flume-ng-sdk</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
     </dependency>
     <dependency>
       <groupId>org.apache.flume</groupId>
       <artifactId>flume-ng-core</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
     </dependency>
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <scope>test</scope>
+    </dependency>
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index fd01807fc3ac4..dc2a4ab138e18 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.flume.Channel
import org.apache.commons.lang3.RandomStringUtils
@@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils
private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel,
val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging {
val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads,
- new ThreadFactoryBuilder().setDaemon(true)
- .setNameFormat("Spark Sink Processor Thread - %d").build()))
+ new SparkSinkThreadFactory("Spark Sink Processor Thread - %d")))
// Protected by `sequenceNumberToProcessor`
private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]()
// This sink will not persist sequence numbers and reuses them if it gets restarted.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
similarity index 61%
rename from core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
index d75959f480756..845fc8debda75 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
@@ -14,11 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.util.collection
+package org.apache.spark.streaming.flume.sink

-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
-  def hasNext: Boolean = iter.hasNext
-
-  def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
+import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicLong
+
+/**
+ * Thread factory that generates daemon threads with a specified name format.
+ */
+private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory {
+
+  private val threadId = new AtomicLong()
+
+  override def newThread(r: Runnable): Thread = {
+    val t = new Thread(r, nameFormat.format(threadId.incrementAndGet()))
+    t.setDaemon(true)
+    t
+  }
}
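A quick usage sketch of the factory introduced by this rename (illustrative; it assumes caller code in the org.apache.spark.streaming.flume.sink package, since the class is private[sink]):

import java.util.concurrent.Executors

// Every thread from this pool is a daemon named "Demo Thread - 1", "Demo Thread - 2", ...
val pool = Executors.newFixedThreadPool(2, new SparkSinkThreadFactory("Demo Thread - %d"))
pool.submit(new Runnable {
  override def run(): Unit = println(Thread.currentThread().getName)
})
pool.shutdown()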
diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
index 650b2fbe1c142..fa43629d49771 100644
--- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
+++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
@@ -24,16 +24,24 @@ import scala.collection.JavaConversions._
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.event.EventBuilder
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
+
+// Due to MNG-1378, there is not a way to include test dependencies transitively.
+// We cannot include Spark core tests as a dependency here because it depends on
+// Spark core main, which has too many dependencies to require here manually.
+// For this reason, we continue to use FunSuite and ignore the scalastyle checks
+// that fail if this is detected.
+//scalastyle:off
import org.scalatest.FunSuite
class SparkSinkSuite extends FunSuite {
+//scalastyle:on
+
val eventsPerBatch = 1000
val channelCapacity = 5000
@@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite {
count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = {
(1 to count).map(_ => {
- lazy val channelFactoryExecutor =
- Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).
- setNameFormat("Flume Receiver Channel Thread - %d").build())
+ lazy val channelFactoryExecutor = Executors.newCachedThreadPool(
+ new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d"))
lazy val channelFactory =
new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor)
val transceiver = new NettyTransceiver(address, channelFactory)
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 8df7edbdcad33..a345c03582ad6 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming-flume-sink_${scala.binary.version}</artifactId>
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 60e2994431b38..1e32a365a1eee 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -152,9 +152,9 @@ class FlumeReceiver(
val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(),
Executors.newCachedThreadPool())
val channelPipelineFactory = new CompressionChannelPipelineFactory()
-
+
new NettyServer(
- responder,
+ responder,
new InetSocketAddress(host, port),
channelFactory,
channelPipelineFactory,
@@ -188,12 +188,12 @@ class FlumeReceiver(
override def preferredLocation: Option[String] = Option(host)
- /** A Netty Pipeline factory that will decompress incoming data from
+ /** A Netty Pipeline factory that will decompress incoming data from
* the Netty client and compress data going back to the client.
*
* The compression on the return is required because Flume requires
- * a successful response to indicate it can remove the event/batch
- * from the configured channel
+ * a successful response to indicate it can remove the event/batch
+ * from the configured channel
*/
private[streaming]
class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
index 92fa5b41be89e..583e7dca317ad 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
@@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver(
}
/**
- * A wrapper around the transceiver and the Avro IPC API.
+ * A wrapper around the transceiver and the Avro IPC API.
* @param transceiver The transceiver to use for communication with Flume
* @param client The client that the callbacks are received on.
*/
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 93afe50c2134f..d772b9ca9b570 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables
import org.apache.flume.event.EventBuilder
import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.streaming.flume.sink._
import org.apache.spark.util.{ManualClock, Utils}
-class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchCount = 5
val eventsPerBatch = 100
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 39e6754c81dbf..c926359987d89 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
import org.apache.spark.util.Utils
-class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
var ssc: StreamingContext = null
@@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
val status = client.appendBatch(inputEvents.toList)
status should be (avro.Status.OK)
}
-
+
eventually(timeout(10 seconds), interval(100 milliseconds)) {
val outputEvents = outputBuffer.flatten.map { _.event }
outputEvents.foreach {
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 243ce6eaca658..5734d55bf4784 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.kafka</groupId>
       <artifactId>kafka_${scala.binary.version}</artifactId>
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index 6cf254a7b69cb..65d51d87f8486 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
r.flatMap { tm: TopicMetadata =>
tm.partitionsMetadata.map { pm: PartitionMetadata =>
TopicAndPartition(tm.topic, pm.partitionId)
- }
+ }
}
}
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 8be2707528d93..0b8a391a2c569 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -315,7 +315,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -363,7 +363,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -427,7 +427,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -489,7 +489,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
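The offset-tracking note repeated in the scaladoc above can be made concrete with a short sketch; it assumes `ssc: StreamingContext`, `kafkaParams: Map[String, String]`, and `topics: Set[String]` are already defined by the caller:

import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka._

// Direct stream: the consumed offsets live in the stream itself, not in Zookeeper.
val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, topics)

stream.foreachRDD { rdd =>
  // Each batch RDD carries the Kafka offset ranges it was built from.
  val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  ranges.foreach { o =>
    println(s"topic=${o.topic} partition=${o.partition} range=[${o.fromOffset}, ${o.untilOffset})")
  }
}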
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index b6d314dfc7783..47bbfb605850a 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -28,10 +28,10 @@ import scala.language.postfixOps
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.Utils
class DirectKafkaStreamSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfter
with BeforeAndAfterAll
with Eventually
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index 7fb841b79cb65..d66830cbacdee 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka
import scala.util.Random
import kafka.common.TopicAndPartition
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
+import org.apache.spark.SparkFunSuite
+
+class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll {
private val topic = "kcsuitetopic" + Random.nextInt(10000)
private val topicAndPartition = TopicAndPartition(topic, 0)
private var kc: KafkaCluster = null
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 3c875cb766513..054487269a935 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -22,11 +22,11 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
-class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
+class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private var kafkaTestUtils: KafkaTestUtils = _
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index 24699dfc33adb..8ee2cc660f849 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -23,14 +23,14 @@ import scala.language.postfixOps
import scala.util.Random
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
+class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
private var ssc: StreamingContext = _
private var kafkaTestUtils: KafkaTestUtils = _
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 38548dd73b82c..80e2df62de3fe 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -26,15 +26,15 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils
-class ReliableKafkaStreamSuite extends FunSuite
+class ReliableKafkaStreamSuite extends SparkFunSuite
with BeforeAndAfterAll with BeforeAndAfter with Eventually {
private val sparkConf = new SparkConf()
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 98f95a9a64fa0..7d102e10ab60f 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.eclipse.paho</groupId>
       <artifactId>org.eclipse.paho.client.mqttv3</artifactId>
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index a19a72c58a705..c4bf5aa7869bb 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
@@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
-class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
private val master = "local[2]"
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 8b6a8959ac4cf..d28e3e1846d70 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.twitter4j</groupId>
       <artifactId>twitter4j-stream</artifactId>
diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
index 9ee57d7581d85..d9acb568879fe 100644
--- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
+++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
@@ -18,16 +18,16 @@
package org.apache.spark.streaming.twitter
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchDuration = Seconds(1)
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index a50d378b34335..9998c11c85171 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -41,6 +41,13 @@
       <version>${project.version}</version>
       <scope>provided</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>${akka.group}</groupId>
       <artifactId>akka-zeromq_${scala.binary.version}</artifactId>
diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
index a7566e733d891..35d2e62c68480 100644
--- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
+++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq
import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class ZeroMQStreamSuite extends FunSuite {
+class ZeroMQStreamSuite extends SparkFunSuite {
val batchDuration = Seconds(1)
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index 97c3476049289..be8b62d3cc6ba 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging {
val batchInterval = Milliseconds(2000)
// Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information
- // on sequence number of records that have been received. Same as batchInterval for this
+ // on sequence number of records that have been received. Same as batchInterval for this
// example.
val kinesisCheckpointInterval = batchInterval
@@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging {
// Map each word to a (word, 1) tuple so we can reduce by key to count the words
val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _)
-
+
// Print the first 10 wordCounts
wordCounts.print()
@@ -210,14 +210,14 @@ object KinesisWordProducerASL {
val randomWords = List("spark", "you", "are", "my", "father")
val totals = scala.collection.mutable.Map[String, Int]()
-
+
// Create the low-level Kinesis Client from the AWS Java SDK.
val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain())
kinesisClient.setEndpoint(endpoint)
println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" +
s" $recordsPerSecond records per second and $wordsPerRecord words per record")
-
+
// Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord
for (i <- 1 to 10) {
// Generate recordsPerSec records to put onto the stream
@@ -255,8 +255,8 @@ object KinesisWordProducerASL {
}
}
-/**
- * Utility functions for Spark Streaming examples.
+/**
+ * Utility functions for Spark Streaming examples.
* This has been lifted from the examples/ project to remove the circular dependency.
*/
private[streaming] object StreamingExamples extends Logging {
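As a rough sketch of how the checkpoint interval above is wired into a receiver; the exact `KinesisUtils.createStream` signature here is an assumption for this vintage of the API, and `ssc`, `streamName`, and `endpointUrl` come from the surrounding example:

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.kinesis.KinesisUtils

// Checkpoint to DynamoDB at the same cadence as the batch interval (see above).
val kinesisStream = KinesisUtils.createStream(ssc, streamName, endpointUrl,
  kinesisCheckpointInterval, InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)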
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
index 1c9b0c218ae18..83a4537559512 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
@@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock}
/**
* This is a helper class for managing checkpoint clocks.
*
- * @param checkpointInterval
+ * @param checkpointInterval
* @param currentClock Defaults to the current SystemClock if none is passed in (for mocking purposes)
*/
private[kinesis] class KinesisCheckpointState(
- checkpointInterval: Duration,
+ checkpointInterval: Duration,
currentClock: Clock = new SystemClock())
extends Logging {
-
+
/* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */
val checkpointClock = new ManualClock()
checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds)
/**
- * Check if it's time to checkpoint based on the current time and the derived time
+ * Check if it's time to checkpoint based on the current time and the derived time
* for the next checkpoint
*
* @return true if it's time to checkpoint
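A small sketch of the clock pattern this class implements; the `shouldCheckpoint`/`advanceCheckpoint` method names are assumed from the surrounding source rather than shown in this hunk, and the class itself is private[kinesis]:

import org.apache.spark.streaming.Duration

// Checkpoint at most once per minute: the manual clock holds the next deadline.
val state = new KinesisCheckpointState(Duration(60000))
if (state.shouldCheckpoint()) {  // assumed: true once the system clock passes checkpointClock
  // ... perform the Kinesis checkpoint ...
  state.advanceCheckpoint()      // assumed: pushes checkpointClock forward one interval
}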
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 7dd8bfdc2a6db..1a8a4cecc1141 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* https://github.com/awslabs/amazon-kinesis-client
* This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here:
* http://spark.apache.org/docs/latest/streaming-custom-receivers.html
- * Instances of this class will get shipped to the Spark Streaming Workers to run within a
+ * Instances of this class will get shipped to the Spark Streaming Workers to run within a
* Spark Executor.
*
* @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams
* by the Kinesis Client Library. If you change the App name or Stream name,
- * the KCL will throw errors. This usually requires deleting the backing
+ * the KCL will throw errors. This usually requires deleting the backing
* DynamoDB table with the same name as this Kinesis application.
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
@@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver(
*/
/**
- * workerId is used by the KCL should be based on the ip address of the actual Spark Worker
+ * workerId is used by the KCL and should be based on the IP address of the actual Spark Worker
 * where this code runs (not the driver's IP address).
*/
private var workerId: String = null
@@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver(
/*
* RecordProcessorFactory creates impls of IRecordProcessor.
- * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
+ * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
* IRecordProcessor.processRecords() method.
* We're using our custom KinesisRecordProcessor in this case.
*/
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index f65e743c4e2a3..fe9e3a0c793e2 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record
/**
* Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor.
* This implementation operates on the Array[Byte] from the KinesisReceiver.
- * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
- * shard in the Kinesis stream upon startup. This is normally done in separate threads,
- * but the KCLs within the KinesisReceivers will balance themselves out if you create
+ * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
+ * shard in the Kinesis stream upon startup. This is normally done in separate threads,
+ * but the KCLs within the KinesisReceivers will balance themselves out if you create
* multiple Receivers.
*
* @param receiver Kinesis receiver
@@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor(
* and Spark Streaming's Receiver.store().
*
* @param batch list of records from the Kinesis stream shard
- * @param checkpointer used to update Kinesis when this batch has been processed/stored
+ * @param checkpointer used to update Kinesis when this batch has been processed/stored
* in the DStream
*/
override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
if (!receiver.isStopped()) {
try {
/*
- * Notes:
+ * Notes:
* 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming
* Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the
* internally-configured Spark serializer (kryo, etc).
@@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor(
* ourselves from Spark's internal serialization strategy.
* 3) For performance, the BlockGenerator is asynchronously queuing elements within its
* memory before creating blocks. This prevents the small block scenario, but requires
- * that you register callbacks to know when a block has been generated and stored
+ * that you register callbacks to know when a block has been generated and stored
* (WAL is sufficient for storage) before you can checkpoint back to the source.
*/
batch.foreach(record => receiver.store(record.getData().array()))
-
+
logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
/*
- * Checkpoint the sequence number of the last record successfully processed/stored
+ * Checkpoint the sequence number of the last record successfully processed/stored
* in the batch.
* In this implementation, we're checkpointing after the given checkpointIntervalMillis.
- * Note that this logic requires that processRecords() be called AND that it's time to
- * checkpoint. I point this out because there is no background thread running the
+ * Note that this logic requires that processRecords() be called AND that it's time to
+ * checkpoint. I point this out because there is no background thread running the
* checkpointer. Checkpointing is tested and triggered only when a new batch comes in.
* If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below).
* However, if the worker dies unexpectedly, a checkpoint may not happen.
@@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor(
}
} else {
/* RecordProcessor has been stopped. */
- logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
+ logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
s" and shardId $shardId. No more records will be processed.")
}
}
/**
* Kinesis Client Library is shutting down this Worker for 1 of 2 reasons:
- * 1) the stream is resharding by splitting or merging adjacent shards
+ * 1) the stream is resharding by splitting or merging adjacent shards
* (ShutdownReason.TERMINATE)
- * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
+ * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
* (ShutdownReason.ZOMBIE)
*
* @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE
@@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor(
* Checkpoint to indicate that all records from the shard have been drained and processed.
* It's now OK to read from the new shards that resulted from a resharding event.
*/
- case ShutdownReason.TERMINATE =>
+ case ShutdownReason.TERMINATE =>
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
/*
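The `retryRandom` call above retries a by-name expression with random backoff. A toy reimplementation of that pattern, as a hedged sketch rather than Spark's actual helper:

import scala.util.Random

// Evaluate `expression` up to numRetries times, sleeping a random interval
// (bounded by maxBackOffMillis) between attempts; the last failure propagates.
def retryRandom[T](expression: => T, numRetries: Int, maxBackOffMillis: Int): T = {
  try {
    expression
  } catch {
    case e: Exception if numRetries > 1 =>
      Thread.sleep(Random.nextInt(maxBackOffMillis))
      retryRandom(expression, numRetries - 1, maxBackOffMillis)
  }
}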
diff --git a/graphx/pom.xml b/graphx/pom.xml
index d38a3aa8256b7..28b41228feb3d 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index cc70b396a8dd4..4611a3ace219b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -41,14 +41,16 @@ abstract class EdgeRDD[ED](
@transient sc: SparkContext,
@transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) {
+ // scalastyle:off structural.type
private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD }
+ // scalastyle:on structural.type
override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context)
if (p.hasNext) {
- p.next._2.iterator.map(_.copy())
+ p.next()._2.iterator.map(_.copy())
} else {
Iterator.empty
}
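The `forSome` clause that the scalastyle toggle wraps is an existential type: it lets `partitionsRDD` carry edge partitions whose vertex-attribute type is hidden from `EdgeRDD`'s public signature. A toy illustration of the same pattern (the types below are made up, not GraphX's):

import scala.language.existentials

class Part[ED, VD](val edges: Seq[ED], val vertices: Seq[VD])

// VD is existentially quantified, so partitions with different vertex types mix freely.
val parts: List[Part[Int, VD] forSome { type VD }] =
  List(new Part(Seq(1, 2), Seq("a")), new Part(Seq(3), Seq(42.0)))

// Callers can still use everything that does not mention the hidden type.
def firstEdge(p: Part[Int, VD] forSome { type VD }): Int = p.edges.head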
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
index eb1dbe52c2fda..f1ecc9e2219d1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
-class EdgeRDDSuite extends FunSuite with LocalSparkContext {
+class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value after caching
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
index 5a2c73b414279..094a63472eaab 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
@@ -17,21 +17,21 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class EdgeSuite extends FunSuite {
+class EdgeSuite extends SparkFunSuite {
test ("compare") {
// descending order
val testEdges: Array[Edge[Int]] = Array(
- Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
- Edge(0x2345L, 0x1234L, 1),
- Edge(0x1234L, 0x5678L, 1),
- Edge(0x1234L, 0x2345L, 1),
+ Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
+ Edge(0x2345L, 0x1234L, 1),
+ Edge(0x1234L, 0x5678L, 1),
+ Edge(0x1234L, 0x2345L, 1),
Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1)
)
// to ascending order
val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int])
-
+
for (i <- 0 until testEdges.length) {
assert(sortedEdges(i) == testEdges(testEdges.length - i - 1))
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index 68fe83739e399..57a8b95dd12e9 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.impl.EdgePartition
import org.apache.spark.rdd._
-import org.scalatest.FunSuite
-class GraphOpsSuite extends FunSuite with LocalSparkContext {
+class GraphOpsSuite extends SparkFunSuite with LocalSparkContext {
test("joinVertices") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 2b1d8e47326f8..1f5e27d5508b8 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class GraphSuite extends FunSuite with LocalSparkContext {
+class GraphSuite extends SparkFunSuite with LocalSparkContext {
def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = {
Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v")
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
index 490b94429ea1f..8afa2d403b53f 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.rdd._
-class PregelSuite extends FunSuite with LocalSparkContext {
+class PregelSuite extends SparkFunSuite with LocalSparkContext {
test("1 iteration") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index d0a7198d691d7..f1aa685a79c98 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.{HashPartitioner, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-class VertexRDDSuite extends FunSuite with LocalSparkContext {
+class VertexRDDSuite extends SparkFunSuite with LocalSparkContext {
private def vertices(sc: SparkContext, n: Int) = {
VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 515f3a9cd02eb..7435647c6d9ee 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl
import scala.reflect.ClassTag
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class EdgePartitionSuite extends FunSuite {
+class EdgePartitionSuite extends SparkFunSuite {
def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = {
val builder = new EdgePartitionBuilder[A, Int]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index fe8304c1cdc32..1203f8959f506 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.graphx.impl
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class VertexPartitionSuite extends FunSuite {
+class VertexPartitionSuite extends SparkFunSuite {
test("isDefined, filter") {
val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
index accccfc232cd3..c965a6eb8df13 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Grid Connected Components") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
index 61fd0c4605568..808877f0590f8 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class LabelPropagationSuite extends FunSuite with LocalSparkContext {
+class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext {
test("Label Propagation") {
withSpark { sc =>
// Construct a graph with two cliques connected by a single edge
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index 39c6ace912b00..45f1e3011035e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
@@ -57,7 +56,7 @@ object GridPageRank {
}
-class PageRankSuite extends FunSuite with LocalSparkContext {
+class PageRankSuite extends SparkFunSuite with LocalSparkContext {
def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index 7bd6b7f3c4ab2..2991438f5e57e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
+class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
index f2c38e79c452c..d7eaa70ce6407 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ShortestPathsSuite extends FunSuite with LocalSparkContext {
+class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext {
test("Shortest Path Computations") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
index 1f658c371ffcf..d6b03208180db 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Island Strongly Connected Components") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
index 79bf4e6cd18ee..c47552cf3a3bd 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut
-class TriangleCountSuite extends FunSuite with LocalSparkContext {
+class TriangleCountSuite extends SparkFunSuite with LocalSparkContext {
test("Count a single triangle") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
index f3b3738db0dad..186d0cc2a977b 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class BytecodeUtilsSuite extends FunSuite {
+class BytecodeUtilsSuite extends SparkFunSuite {
import BytecodeUtilsSuite.TestClass
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index 8d9c8ddccbb3c..32e0c841c6997 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.LocalSparkContext
-class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
+class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext {
test("GraphGenerators.generateRandomEdges") {
val src = 5
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 33fd813f7a86c..33d65d13f0d25 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException {
try {
fd = new FileInputStream(propsFile);
props.load(new InputStreamReader(fd, "UTF-8"));
+      for (Map.Entry<Object, Object> e : props.entrySet()) {
+        e.setValue(e.getValue().toString().trim());
+      }
diff --git a/mllib/pom.xml b/mllib/pom.xml
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d8592eb2d947d..62f4b51f770e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -208,7 +208,7 @@ private[ml] object GBTClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index b8c7f3c5bc3b9..825f9ed1b54b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel
*/
private[ml] trait OneVsRestParams extends PredictorParams {
+ // scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
type M <: ClassificationModel[F, M]
type E <: Classifier[F, E, M]
}
+ // scalastyle:on structural.type
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -129,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
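A usage sketch for the estimator/model pair above, assuming DataFrames `training` and `test` with the usual `label`/`features` columns:

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// One binary logistic regression model is fit per class label.
val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))

val ovrModel = ovr.fit(training)
// transform() now returns only the prediction; the temporary accumulator
// column is dropped by the change above.
val predictions = ovrModel.transform(test)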
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 67600ebd7b38e..852a67e066322 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -170,7 +170,7 @@ private[ml] object RandomForestClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index eb6ec49f854be..8f34878c8d329 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,94 +17,152 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{DoubleType, StructType}
/**
* :: Experimental ::
- * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
- * at most a single one-value. By default, the binary vector has an element for each category, so
- * with 5 categories, an input value of 2.0 would map to an output vector of
- * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
- * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
- * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
- * linearly dependent because they sum up to one.
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]])
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * @see [[StringIndexer]] for converting categorical values into category indices
*/
@Experimental
-class OneHotEncoder(override val uid: String)
- extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol {
+class OneHotEncoder(override val uid: String) extends Transformer
+ with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("oneHot"))
/**
- * Whether to include a component in the encoded vectors for the first category, defaults to true.
+ * Whether to drop the last category in the encoded vector (default: true)
* @group param
*/
- final val includeFirst: BooleanParam =
- new BooleanParam(this, "includeFirst", "include first category")
- setDefault(includeFirst -> true)
-
- private var categories: Array[String] = _
+ final val dropLast: BooleanParam =
+ new BooleanParam(this, "dropLast", "whether to drop the last category")
+ setDefault(dropLast -> true)
/** @group setParam */
- def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+ def setDropLast(value: Boolean): this.type = set(dropLast, value)
/** @group setParam */
- override def setInputCol(value: String): this.type = set(inputCol, value)
+ def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
- override def setOutputCol(value: String): this.type = set(outputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
- val inputFields = schema.fields
+ val is = "_is_"
+ val inputColName = $(inputCol)
val outputColName = $(outputCol)
- require(inputFields.forall(_.name != $(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
- categories = inputColAttr match {
+ SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+ val inputFields = schema.fields
+ require(!inputFields.exists(_.name == outputColName),
+ s"Output column $outputColName already exists.")
+
+ val inputAttr = Attribute.fromStructField(schema(inputColName))
+ val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
- nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
- case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+ if (nominal.values.isDefined) {
+ nominal.values.map(_.map(v => inputColName + is + v))
+ } else if (nominal.numValues.isDefined) {
+ nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+ } else {
+ None
+ }
+ case binary: BinaryAttribute =>
+ if (binary.values.isDefined) {
+ binary.values.map(_.map(v => inputColName + is + v))
+ } else {
+ Some(Array.tabulate(2)(i => inputColName + is + i))
+ }
+ case _: NumericAttribute =>
+ throw new RuntimeException(
+ s"The input column $inputColName cannot be numeric.")
case _ =>
- throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+ None // optimistic about unknown attributes
}
- val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
- val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
- val outputFields = inputFields :+ attr.toStructField()
+ val filteredOutputAttrNames = outputAttrNames.map { names =>
+ if ($(dropLast)) {
+ require(names.length > 1,
+ s"The input column $inputColName should have at least two distinct values.")
+ names.dropRight(1)
+ } else {
+ names
+ }
+ }
+
+ val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
+ val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
+ BinaryAttribute.defaultAttr.withName(name)
+ }
+ new AttributeGroup($(outputCol), attrs)
+ } else {
+ new AttributeGroup($(outputCol))
+ }
+
+ val outputFields = inputFields :+ outputAttrGroup.toStructField()
StructType(outputFields)
}
- protected override def createTransformFunc(): (Double) => Vector = {
- val first = $(includeFirst)
- val vecLen = if (first) categories.length else categories.length - 1
+ override def transform(dataset: DataFrame): DataFrame = {
+ // schema transformation
+ val is = "_is_"
+ val inputColName = $(inputCol)
+ val outputColName = $(outputCol)
+ val shouldDropLast = $(dropLast)
+ var outputAttrGroup = AttributeGroup.fromStructField(
+ transformSchema(dataset.schema)(outputColName))
+ if (outputAttrGroup.size < 0) {
+ // If the number of attributes is unknown, we check the values from the input column.
+ val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0))
+ .aggregate(0.0)(
+ (m, x) => {
+ assert(x >= 0.0 && x == x.toInt,
+ s"Values from column $inputColName must be indices, but got $x.")
+ math.max(m, x)
+ },
+ (m0, m1) => {
+ math.max(m0, m1)
+ }
+ ).toInt + 1
+ val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+ val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
+ val outputAttrs: Array[Attribute] =
+ filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
+ outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
+ }
+ val metadata = outputAttrGroup.toMetadata()
+
+ // data transformation
+ val size = outputAttrGroup.size
val oneValue = Array(1.0)
val emptyValues = Array[Double]()
val emptyIndices = Array[Int]()
- label: Double => {
- val values = if (first || label != 0.0) oneValue else emptyValues
- val indices = if (first) {
- Array(label.toInt)
- } else if (label != 0.0) {
- Array(label.toInt - 1)
+ val encode = udf { label: Double =>
+ if (label < size) {
+ Vectors.sparse(size, Array(label.toInt), oneValue)
} else {
- emptyIndices
+ Vectors.sparse(size, emptyIndices, emptyValues)
}
- Vectors.sparse(vecLen, indices, values)
}
- }
- /**
- * Returns the data type of the output column.
- */
- protected def outputDataType: DataType = new VectorUDT
+ dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
+ }
}
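
A minimal sketch of the reworked encoder, assuming a SQLContext named sqlContext is in scope and the data is hypothetical: category indices come from StringIndexer, and with dropLast left at its default the last category encodes as the all-zero vector.

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a")
)).toDF("id", "category")

val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

val encoded = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .transform(indexed)
encoded.show() // sparse vectors of length numCategories - 1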
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index fdd2494fc87a6..b0fd06d84fdb3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/**
* Centers the data with mean before scaling.
- * It will build a dense output, so this does not work on sparse input
+ * It will build a dense output, so this does not work on sparse input
* and will raise an exception.
* Default: false
* @group param
*/
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
-
+
/**
* Scales the data to unit standard deviation.
* Default: true
@@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
-
+
/** @group setParam */
def setWithMean(value: Boolean): this.type = set(withMean, value)
-
+
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
-
+
override def fit(dataset: DataFrame): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 69f4f5414c8c6..b7e374bb6cb49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7c40db1a40040..fe2a71a331694 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -321,7 +321,7 @@ private class LeastSquaresAggregator(
}
(weightsArray, -sum + labelMean / labelStd, weightsArray.length)
}
-
+
private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
private val gradientSumArray = Array.ofDim[Double](dim)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ae767a17329d2..49a1f7ce8c995 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -152,7 +152,7 @@ private[ml] object RandomForestRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 65f30fdba7393..16f3131796709 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -399,7 +399,7 @@ private[python] class PythonMLLibAPI extends Serializable {
val sigma = si.map(_.asInstanceOf[DenseMatrix])
val gaussians = Array.tabulate(weight.length){
i => new MultivariateGaussian(mean(i), sigma(i))
- }
+ }
val model = new GaussianMixtureModel(weight, gaussians)
model.predictSoft(data).map(Vectors.dense)
}
@@ -494,7 +494,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
new Normalizer(p).transform(rdd)
}
-
+
/**
* Java stub for StandardScaler.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index e9a23e40cc790..70b0e40948e51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -36,11 +36,11 @@ import org.apache.spark.util.Utils
* independent Gaussian distributions with associated "mixing" weights
* specifying each's contribution to the composite.
*
- * Given a set of sample points, this class will maximize the log-likelihood
- * for a mixture of k Gaussians, iterating until the log-likelihood changes by
+ * Given a set of sample points, this class will maximize the log-likelihood
+ * for a mixture of k Gaussians, iterating until the log-likelihood changes by
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
- * to find a global optimum.
+ * to find a global optimum.
*
* Note: For high-dimensional data (with many features), this algorithm may perform poorly.
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
@@ -53,24 +53,24 @@ import org.apache.spark.util.Utils
*/
@Experimental
class GaussianMixture private (
- private var k: Int,
- private var convergenceTol: Double,
+ private var k: Int,
+ private var convergenceTol: Double,
private var maxIterations: Int,
private var seed: Long) extends Serializable {
-
+
/**
* Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
* maxIterations: 100, seed: random}.
*/
def this() = this(2, 0.01, 100, Utils.random.nextLong())
-
+
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5
-
- // an initializing GMM can be provided rather than using the
+
+ // an initializing GMM can be provided rather than using the
// default random starting point
private var initialModel: Option[GaussianMixtureModel] = None
-
+
/** Set the initial GMM starting point, bypassing the random initialization.
* You must call setK() prior to calling this method, and the condition
* (model.k == this.k) must be met; failure will result in an IllegalArgumentException
@@ -83,37 +83,37 @@ class GaussianMixture private (
}
this
}
-
+
/** Return the user supplied initial GMM, if supplied */
def getInitialModel: Option[GaussianMixtureModel] = initialModel
-
+
/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
this.k = k
this
}
-
+
/** Return the number of Gaussians in the mixture model */
def getK: Int = k
-
+
/** Set the maximum number of iterations to run. Default: 100 */
def setMaxIterations(maxIterations: Int): this.type = {
this.maxIterations = maxIterations
this
}
-
+
/** Return the maximum number of iterations to run */
def getMaxIterations: Int = maxIterations
-
+
/**
- * Set the largest change in log-likelihood at which convergence is
+ * Set the largest change in log-likelihood at which convergence is
* considered to have occurred.
*/
def setConvergenceTol(convergenceTol: Double): this.type = {
this.convergenceTol = convergenceTol
this
}
-
+
/**
* Return the largest change in log-likelihood at which convergence is
* considered to have occurred.
@@ -132,41 +132,41 @@ class GaussianMixture private (
/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext
-
+
// we will operate on the data as breeze data
val breezeData = data.map(_.toBreeze).cache()
-
+
// Get length of the input vectors
val d = breezeData.first().length
-
+
// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
// we start with uniform weights, a random mean from the data, and
// diagonal covariance matrices using component variances
- // derived from the samples
+ // derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weights, gmm.gaussians)
-
+
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
- (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
+ (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
- new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+ new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
}
}
-
- var llh = Double.MinValue // current log-likelihood
+
+ var llh = Double.MinValue // current log-likelihood
var llhp = 0.0 // previous log-likelihood
-
+
var iter = 0
while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) {
// create and broadcast curried cluster contribution function
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
-
+
// aggregate the cluster contribution for all sample points
val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
-
+
// Create new distributions based on the partial assignments
// (often referred to as the "M" step in literature)
val sumWeights = sums.weights.sum
@@ -179,22 +179,22 @@ class GaussianMixture private (
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
i = i + 1
}
-
+
llhp = llh // current becomes previous
llh = sums.logLikelihood // this is the freshly computed log-likelihood
iter += 1
- }
-
+ }
+
new GaussianMixtureModel(weights, gaussians)
}
-
+
/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
- v / x.length.toDouble
+ v / x.length.toDouble
}
-
+
/**
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
@@ -210,14 +210,14 @@ class GaussianMixture private (
// companion class to provide zero constructor for ExpectationSum
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
- new ExpectationSum(0.0, Array.fill(k)(0.0),
+ new ExpectationSum(0.0, Array.fill(k)(0.0),
Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
}
-
+
// compute cluster contributions for each input point
// (U, T) => U for aggregation
def add(
- weights: Array[Double],
+ weights: Array[Double],
dists: Array[MultivariateGaussian])
(sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
@@ -235,7 +235,7 @@ private object ExpectationSum {
i = i + 1
}
sums
- }
+ }
}
// Aggregation class for partial expectation results
@@ -244,9 +244,9 @@ private class ExpectationSum(
val weights: Array[Double],
val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
-
+
val k = weights.length
-
+
def +=(x: ExpectationSum): ExpectationSum = {
var i = 0
while (i < k) {
@@ -257,5 +257,5 @@ private class ExpectationSum(
}
logLikelihood += x.logLikelihood
this
- }
+ }
}
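
The EM loop above is driven entirely through the public setters; a short sketch of fitting a two-component mixture on a toy RDD (hypothetical data, sc assumed in scope):

import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

val data = sc.parallelize(Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.0, 0.3),
  Vectors.dense(9.8, 9.9), Vectors.dense(10.1, 9.7)
))
val gmm = new GaussianMixture()
  .setK(2)
  .setConvergenceTol(0.01)
  .setMaxIterations(100)
  .run(data)
gmm.weights.zip(gmm.gaussians).foreach { case (w, g) =>
  println(s"weight=$w mean=${g.mu} cov=${g.sigma}")
}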
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 86353aed81156..5fc2cb1b62d33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -34,10 +34,10 @@ import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
*
- * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
- * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
- * the respective mean and covariance for each Gaussian distribution i=1..k.
- *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
+ * the respective mean and covariance for each Gaussian distribution i=1..k.
+ *
* @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is
* the weight for Gaussian i, and weights.sum == 1
* @param gaussians Array of MultivariateGaussian where gaussians(i) represents
@@ -45,9 +45,9 @@ import org.apache.spark.sql.{SQLContext, Row}
*/
@Experimental
class GaussianMixtureModel(
- val weights: Array[Double],
+ val weights: Array[Double],
 val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
-
+
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
override protected def formatVersion = "1.0"
@@ -64,20 +64,20 @@ class GaussianMixtureModel(
val responsibilityMatrix = predictSoft(points)
responsibilityMatrix.map(r => r.indexOf(r.max))
}
-
+
/**
* Given the input vectors, return the membership value of each vector
- * to all mixture components.
+ * to all mixture components.
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val bcDists = sc.broadcast(gaussians)
val bcWeights = sc.broadcast(weights)
- points.map { x =>
+ points.map { x =>
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
-
+
/**
* Compute the partial assignments for each vector
*/
@@ -89,7 +89,7 @@ class GaussianMixtureModel(
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
}
- val pSum = p.sum
+ val pSum = p.sum
for (i <- 0 until k) {
p(i) /= pSum
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 1ed01c9d8ba0b..e7a243f854e33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] (
import org.apache.spark.mllib.clustering.PowerIterationClustering._
/** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
- * initMode: "random"}.
+ * initMode: "random"}.
*/
def this() = this(k = 2, maxIterations = 100, initMode = "random")
@@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging {
/**
* Generates random vertex properties (v0) to start power iteration.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing a random vector
* with unit 1-norm
@@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging {
* Generates the degree vector as the vertex properties (v0) to start power iteration.
 * It is not exactly the node degrees but the normalized sum of similarities. We call it
 * the degree vector because that is the term used in the PIC paper.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing the degree vector
*/
@@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging {
val v0 = g.vertices.mapValues(_ / sum)
GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
}
-
+
/**
* Runs power iteration.
* @param g input graph with edges representing the normalized affinity matrix (W) and vertices
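
The step being documented is plain power iteration on the normalized affinity matrix: multiply by W and renormalize to unit 1-norm until convergence. A dense, non-GraphX sketch of the idea (illustrative only; the real implementation operates on graph vertices and edges):

def powerIterate(w: Array[Array[Double]], v0: Array[Double], iters: Int): Array[Double] = {
  var v = v0
  for (_ <- 0 until iters) {
    // v <- W v
    val wv = w.map(row => row.zip(v).map { case (wij, vj) => wij * vj }.sum)
    // renormalize to unit 1-norm, matching the vertex properties kept above
    val norm = wv.map(math.abs).sum
    v = wv.map(_ / norm)
  }
  v
}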
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 812014a041719..c21e4fe7dc9b6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -178,7 +178,7 @@ class StreamingKMeans(
/** Set the decay factor directly (for forgetful algorithms). */
def setDecayFactor(a: Double): this.type = {
- this.decayFactor = decayFactor
+ this.decayFactor = a
this
}
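
The one-line fix above corrects a self-assignment: setDecayFactor previously wrote decayFactor back onto itself, so the supplied value was silently ignored. A sketch of the call that now behaves as documented:

import org.apache.spark.mllib.clustering.StreamingKMeans

val skm = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(0.5) // stored correctly after this fix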
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 9cc2d0ffcab7d..5f8c1dea237b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
* (ordered by statistic value descending)
*/
@Experimental
-class ChiSqSelector (val numTopFeatures: Int) {
+class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
/**
* Returns a ChiSquared feature selector.
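
ChiSqSelector is now Serializable, presumably because instances can end up captured in RDD closures. A minimal sketch of the selector API (hypothetical data, sc assumed in scope):

import org.apache.spark.mllib.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val points = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 3.0)),
  LabeledPoint(1.0, Vectors.dense(8.0, 7.0, 0.0))
))
val model = new ChiSqSelector(numTopFeatures = 1).fit(points)
val reduced = points.map(p => model.transform(p.features)) // model is shipped to executors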
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 466ae95859b82..51546d41c36a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.sql.{SQLContext, Row}
/**
- * Entry in vocabulary
+ * Entry in vocabulary
*/
private case class VocabWord(
var word: String,
@@ -56,18 +56,18 @@ private case class VocabWord(
* :: Experimental ::
* Word2Vec creates vector representation of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
- * and then learns vector representation of words in the vocabulary.
- * The vector representation can be used as features in
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
* natural language processing and machine learning algorithms.
- *
- * We used skip-gram model in our implementation and hierarchical softmax
+ *
+ * We used the skip-gram model in our implementation and the hierarchical softmax
 * method to train the model. The variable names in the implementation
 * match those in the original C implementation.
*
- * For original C implementation, see https://code.google.com/p/word2vec/
- * For research papers, see
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
* Efficient Estimation of Word Representations in Vector Space
- * and
+ * and
* Distributed Representations of Words and Phrases and their Compositionality.
*/
@Experimental
@@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging {
private var numIterations = 1
private var seed = Utils.random.nextLong()
private var minCount = 5
-
+
/**
* Sets vector size (default: 100).
*/
@@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging {
this
}
- /**
- * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5).
*/
def setMinCount(minCount: Int): this.type = {
this.minCount = minCount
this
}
-
+
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
@@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging {
.map(x => VocabWord(
x._1,
x._2,
- new Array[Int](MAX_CODE_LENGTH),
- new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
-
+
vocabSize = vocab.length
require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
"the setting of minCount, which could be large enough to remove all your words in sentences.")
@@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging {
}
var pos1 = vocabSize - 1
var pos2 = vocabSize
-
- var min1i = 0
+
+ var min1i = 0
var min2i = 0
a = 0
@@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging {
val words = dataset.flatMap(x => x)
learnVocab(words)
-
+
createBinaryTree()
-
+
val sc = dataset.context
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
-
+
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
@@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging {
}
}
}
-
+
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
@@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging {
}
}
newSentences.unpersist()
-
+
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
@@ -480,7 +480,7 @@ class Word2VecModel private[mllib] (
/**
* Transforms a word to its vector representation
- * @param word a word
+ * @param word a word
* @return vector representation of word
*/
def transform(word: String): Vector = {
@@ -495,7 +495,7 @@ class Word2VecModel private[mllib] (
/**
* Find synonyms of a word
* @param word a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
@@ -506,7 +506,7 @@ class Word2VecModel private[mllib] (
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
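
A brief end-to-end sketch of the Word2Vec API touched above (hypothetical corpus path; sc assumed in scope):

import org.apache.spark.mllib.feature.Word2Vec

val corpus = sc.textFile("corpus.txt").map(_.split(" ").toSeq)
val model = new Word2Vec()
  .setVectorSize(100)
  .setMinCount(5)
  .fit(corpus)
model.findSynonyms("spark", 5).foreach { case (word, cosSim) =>
  println(s"$word: $cosSim")
}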
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ec38529cf8fae..557119f7b1cd1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
_nativeBLAS
}
-
+
/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
@@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging {
j += 1
}
i += 1
- }
+ }
}
private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
@@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging {
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
-
+
/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A and `SparseVector` x.
@@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
}
-
+
/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A and `SparseVector` x.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 866936aa4f118..ae3ba3099c878 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition {
require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
s"k = $k and/or n = $n are too large to compute an eigendecomposition")
-
+
var ido = new intW(0)
var info = new intW(0)
var resid = new Array[Double](n)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
index 34b447584e521..622b53a252ac5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
- model : GeneralizedLinearModel,
+ model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
- threshold: Double)
+ threshold: Double)
extends PMMLModelExport {
populateBinaryClassificationPMML()
@@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
.withUsageType(FieldUsageType.ACTIVE))
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}
-
+
// add target field
val targetField = FieldName.create("target")
dataDictionary
@@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport(
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))
-
+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
-
+
pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
index ebdeae50bb32f..c5fdecd3ca17f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
@@ -25,7 +25,7 @@ import scala.beans.BeanProperty
import org.dmg.pmml.{Application, Header, PMML, Timestamp}
private[mllib] trait PMMLModelExport {
-
+
/**
* Holder of the exported model in PMML format
*/
@@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport {
val pmml: PMML = new PMML
setHeader(pmml)
-
+
private def setHeader(pmml: PMML): Unit = {
val version = getClass.getPackage.getImplementationVersion
val app = new Application().withName("Apache Spark MLlib").withVersion(version)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
index c16e83d6a067d..29bd689e1185a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
@@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
private[mllib] object PMMLModelExportFactory {
-
+
/**
- * Factory object to help creating the necessary PMMLModelExport implementation
+ * Factory object to help create the necessary PMMLModelExport
* taking as input the machine learning model (for example KMeansModel).
*/
def createPMMLModelExport(model: Any): PMMLModelExport = {
@@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory {
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
case svm: SVMModel =>
new BinaryClassificationPMMLModelExport(
- svm, "linear SVM", RegressionNormalizationMethodType.NONE,
+ svm, "linear SVM", RegressionNormalizationMethodType.NONE,
svm.getThreshold.getOrElse(0.0))
case logistic: LogisticRegressionModel =>
if (logistic.numClasses == 2) {
@@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory {
"PMML Export not supported for model: " + model.getClass.getName)
}
}
-
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 7db5a14fd45a5..174d5e0f6c9f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -234,7 +234,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution
- * @param scale scale parameter (> 0) for the gamma distribution
+ * @param scale scale parameter (> 0) for the gamma distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -293,7 +293,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param mean mean for the log normal distribution
- * @param std standard deviation for the log normal distribution
+ * @param std standard deviation for the log normal distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -671,7 +671,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution.
- * @param scale scale parameter (> 0) for the gamma distribution.
+ * @param scale scale parameter (> 0) for the gamma distribution.
* @param numRows Number of Vectors in the RDD.
* @param numCols Number of elements in each Vector.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index dddefe1944e9d..93290e6508529 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -175,7 +175,7 @@ class ALS private (
/**
* :: DeveloperApi ::
* Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default
- * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
+ * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
* `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement,
* at the cost of speed.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 96e50faca2b19..f3b46c75c05f3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
case class Data(boundary: Double, prediction: Double)
def save(
- sc: SparkContext,
- path: String,
- boundaries: Array[Double],
- predictions: Array[Double],
+ sc: SparkContext,
+ path: String,
+ boundaries: Array[Double],
+ predictions: Array[Double],
isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)
val metadata = compact(render(
- ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index cd6add9d60b0d..cf51b24ff777f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils
* the event that the covariance matrix is singular, the density will be computed in a
* reduced dimensional subspace under which the distribution is supported.
* (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
@DeveloperApi
class MultivariateGaussian (
- val mu: Vector,
+ val mu: Vector,
val sigma: Matrix) extends Serializable {
require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
-
+
private val breezeMu = mu.toBreeze.toDenseVector
-
+
/**
* private[mllib] constructor
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
}
-
+
/**
* Compute distribution dependent constants:
* rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
- * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
-
+
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
}
-
+
/** Returns density of this multivariate Gaussian at given point, x */
private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
}
-
+
/**
* Calculate distribution dependent components used for the density function:
* pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
* where k is length of the mean vector.
- *
- * We here compute distribution-fixed parts
+ *
+ * We here compute distribution-fixed parts
* log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
* and
* D^(-1/2)^ * U, where sigma = U * D * U.t
- *
+ *
* Both the determinant and the inverse can be computed from the singular value decomposition
* of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
* we can use the eigendecomposition. We also do not compute the inverse directly; noting
- * that
- *
+ * that
+ *
* sigma = U * D * U.t
- * inv(Sigma) = U * inv(D) * U.t
+ * inv(Sigma) = U * inv(D) * U.t
* = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
- *
+ *
* and thus
- *
+ *
* -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
- *
- * To guard against singular covariance matrices, this method computes both the
+ *
+ * To guard against singular covariance matrices, this method computes both the
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
* to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
* relation to the maximum singular value (same tolerance used by, e.g., Octave).
*/
private def calculateCovarianceConstants: (DBM[Double], Double) = {
val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
-
+
// For numerical stability, values are considered to be non-zero only if they exceed tol.
// This prevents any inverted value from exceeding (eps * n * max(d))^-1
val tol = MLUtils.EPSILON * max(d) * d.length
-
+
try {
// log(pseudo-determinant) is sum of the logs of all non-zero singular values
val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
-
- // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+
+ // calculate the root-pseudo-inverse of the diagonal matrix of singular values
// by inverting the square root of all non-zero values
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
-
+
(pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
} catch {
case uex: UnsupportedOperationException =>
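
Restated in LaTeX, the derivation in the comment above is

  \Sigma = U D U^{\top}, \qquad
  \Sigma^{-1} = U D^{-1} U^{\top} = (D^{-1/2} U)^{\top} (D^{-1/2} U),

so the quadratic form in the log-density collapses to a squared norm,

  -\tfrac{1}{2}\,(x-\mu)^{\top}\,\Sigma^{-1}\,(x-\mu)
    = -\tfrac{1}{2}\,\bigl\lVert D^{-1/2}\,U\,(x-\mu)\bigr\rVert^{2},

with eigenvalues below the tolerance zeroed out, which yields the pseudo-determinant and pseudo-inverse used for singular covariance matrices.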
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index e3ddc7053693c..a835f96d5d0e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging {
logInfo(s"$timer")
if (persistedInput) input.unpersist()
-
+
if (validate) {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 99d0e3cf2fd6d..069959976a188 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
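
reservoirSampleAndCount draws a uniform fixed-size sample of feature indices in a single pass. A generic sketch of the technique (not Spark's implementation):

import scala.reflect.ClassTag
import scala.util.Random

def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, rng: Random): Array[T] = {
  val reservoir = new Array[T](k)
  var i = 0
  while (i < k && input.hasNext) {   // fill the reservoir with the first k items
    reservoir(i) = input.next()
    i += 1
  }
  while (input.hasNext) {            // item i replaces a random slot w.p. k/(i+1)
    val item = input.next()
    val j = rng.nextInt(i + 1)
    if (j < k) reservoir(j) = item
    i += 1
  }
  reservoir
}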
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index ee710fc1ed299..a6d1398fc267b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -83,7 +83,7 @@ class Node (
def predict(features: Vector) : Double = {
if (isLeaf) {
predict.predict
- } else{
+ } else {
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 8676a8983cf2a..67d7d4d79f08b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -276,7 +276,7 @@ object MLUtils {
}
Vectors.fromBreeze(vector1)
}
-
+
/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
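
The formula the comment refers to is presumably the standard norm expansion,

  \lVert a - b \rVert^{2} = \lVert a \rVert^{2} + \lVert b \rVert^{2} - 2\,a^{\top} b,

which lets precomputed norms be reused but loses precision when \lVert a - b \rVert is small relative to \lVert a \rVert and \lVert b \rVert, hence the numerical-error guard.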
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
new file mode 100644
index 0000000000000..35b18c5308f61
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaStringIndexerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ sqlContext = null;
+ }
+
+ @Test
+ public void testStringIndexer() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("label", StringType, false)
+ });
+ JavaRDD<Row> rdd = jsc.parallelize(
+ Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+ StringIndexer indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex");
+ DataFrame output = indexer.fit(dataset).transform(dataset);
+
+ Assert.assertArrayEquals(
+ new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
+ output.orderBy("id").select("id", "labelIndex").collect());
+ }
+
+ /** An alias for RowFactory.create. */
+ private Row c(Object... values) {
+ return RowFactory.create(values);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
new file mode 100644
index 0000000000000..b7c564caad3bd
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaVectorAssemblerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void testVectorAssembler() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("x", DoubleType, false),
+ createStructField("y", new VectorUDT(), false),
+ createStructField("name", StringType, false),
+ createStructField("z", new VectorUDT(), false),
+ createStructField("n", LongType, false)
+ });
+ Row row = RowFactory.create(
+ 0, 0.0, Vectors.dense(1.0, 2.0), "a",
+ Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
+ JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ VectorAssembler assembler = new VectorAssembler()
+ .setInputCols(new String[] {"x", "y", "z", "n"})
+ .setOutputCol("features");
+ DataFrame output = assembler.transform(dataset);
+ Assert.assertEquals(
+ Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+ output.select("features").first().getAs(0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 67c262d0f9d8d..928301523fba9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class IdentifiableSuite extends FunSuite {
+class IdentifiableSuite extends SparkFunSuite {
import IdentifiableSuite.Test
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 2b04a3034782e..05bf58e63abaf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.ml
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
-class PipelineSuite extends FunSuite {
+class PipelineSuite extends SparkFunSuite {
abstract class MyModel extends Model[MyModel]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 17ddd335deb6d..512cffb1acb66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class AttributeGroupSuite extends FunSuite {
+class AttributeGroupSuite extends SparkFunSuite {
test("attribute group") {
val attrs = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index ec9b717e41ce8..72b575d022547 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class AttributeSuite extends FunSuite {
+class AttributeSuite extends SparkFunSuite {
test("default numeric attribute") {
val attr: NumericAttribute = NumericAttribute.defaultAttr
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 3fdc66be8a314..ae40b0b8ff854 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -266,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index ea86867f1161a..1302da3c373ff 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTClassifierSuite.compareAPIs
@@ -128,7 +127,7 @@ private object GBTClassifierSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = GBTClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9f77d5f3efc55..a755cac3ea76e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 770b56890fa45..1d04ccb509057 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -94,6 +93,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
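
The SPARK-8049 test above pins down the public schema of `OneVsRest.transform`: one-vs-rest fits one binary model per class against a relabeled copy of the data, and those intermediate columns must not leak into the output. A generic check of that property might look like the following sketch; the helper name and `expectedNew` parameter are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical helper: assert a transformer added exactly the expected
// columns and leaked no temporary ones (illustration, not Spark API).
def assertNoLeakedColumns(
    input: DataFrame,
    output: DataFrame,
    expectedNew: Set[String]): Unit = {
  val added = output.schema.fieldNames.toSet -- input.schema.fieldNames.toSet
  assert(added == expectedNew,
    s"unexpected columns leaked into the output: ${added -- expectedNew}")
}
```
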
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index cdbbacab8e0e3..eee9355a67be3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestClassifierSuite.compareAPIs
@@ -158,7 +157,7 @@ private object RandomForestClassifierSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 3ea7aad5274f2..36a1ac6b7996d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression Evaluator: default params") {
/**
@@ -39,7 +38,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
val dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
-
+
/**
* Using the following R code to load the data, train the model and evaluate metrics.
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 8f6c6b39dc93b..7953bd0417191 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Double] = _
@@ -48,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
test("Binarize continuous features with setter") {
val threshold: Double = 0.2
- val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 0391bd8427c2c..507a8a7db24c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -19,15 +19,13 @@ package org.apache.spark.ml.feature
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
+class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
@@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-private object BucketizerSuite extends FunSuite {
+private object BucketizerSuite extends SparkFunSuite {
/** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
require(feature >= splits.head)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 2e4beb0bfff63..7b2d70e644005 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
val hashingTF = new HashingTF
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index f85e85471617a..d83772e8be755 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 9d09f24709e23..9f03470b7f328 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 056b9eda86bba..2e5036a844562 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.col
-
-class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -36,15 +36,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
- test("OneHotEncoder includeFirst = true") {
+ test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setInputCol("labelIndex")
.setOutputCol("labelVec")
+ .setDropLast(false)
val encoded = encoder.transform(transformed)
val output = encoded.select("id", "labelVec").map { r =>
- val vec = r.get(1).asInstanceOf[Vector]
+ val vec = r.getAs[Vector](1)
(r.getInt(0), vec(0), vec(1), vec(2))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
@@ -53,22 +54,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
assert(output === expected)
}
- test("OneHotEncoder includeFirst = false") {
+ test("OneHotEncoder dropLast = true") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
- .setIncludeFirst(false)
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)
val output = encoded.select("id", "labelVec").map { r =>
- val vec = r.get(1).asInstanceOf[Vector]
+ val vec = r.getAs[Vector](1)
(r.getInt(0), vec(0), vec(1))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
- val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
- (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+ val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
+ (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
assert(output === expected)
}
+ test("input column with ML attribute") {
+ val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+ val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
+ .select(col("size").as("size", attr.toMetadata()))
+ val encoder = new OneHotEncoder()
+ .setInputCol("size")
+ .setOutputCol("encoded")
+ val output = encoder.transform(df)
+ val group = AttributeGroup.fromStructField(output.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+ }
+
+ test("input column without ML attribute") {
+ val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
+ val encoder = new OneHotEncoder()
+ .setInputCol("index")
+ .setOutputCol("encoded")
+ val output = encoder.transform(df)
+ val group = AttributeGroup.fromStructField(output.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+ }
}
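
The renamed parameter reads more directly than `includeFirst`: with `dropLast = true` (the new default exercised above), the last category is encoded implicitly as the all-zeros vector, which keeps the encoded columns linearly independent. A minimal, Spark-free sketch of that mapping, assuming 0-based category indices; this is an illustration of the encoding, not the `OneHotEncoder` implementation:

```scala
// With dropLast = true the vector has numCategories - 1 slots and the
// last category encodes as all zeros; with dropLast = false every
// category keeps its own slot.
def oneHot(index: Int, numCategories: Int, dropLast: Boolean): Array[Double] = {
  val size = if (dropLast) numCategories - 1 else numCategories
  val vec = Array.fill(size)(0.0)
  if (index < size) vec(index) = 1.0
  vec
}
```

With the suite's indexing (a -> 0, b -> 2, c -> 1) and `dropLast = true`, "b" maps to `[0.0, 0.0]`, matching the expected set in the test above.
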
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index aa230ca073d5b..feca866cd711d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
+class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Polynomial expansion with default parameter") {
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 89c2fe45573aa..cbf1e8ddcb48a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index eabda089d0988..ac279cb3215c2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
+class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("RegexTokenizer") {
@@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object RegexTokenizerSuite extends FunSuite {
+object RegexTokenizerSuite extends SparkFunSuite {
def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 43534e89928b1..489abb5af7130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index b11b029c6343e..06affc7305cf5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.{BeanInfo, BeanProperty}
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
import VectorIndexerSuite.FeatureData
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index df446d0c22015..94ebc3aebfa37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 1505ad872536b..778abcba22c10 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,8 +19,7 @@ package org.apache.spark.ml.impl
import scala.collection.JavaConverters._
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
-private[ml] object TreeTests extends FunSuite {
+private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label metadata.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 04f2af4727ea4..f80e7749098a5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.param
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class ParamsSuite extends FunSuite {
+class ParamsSuite extends SparkFunSuite {
test("param") {
val solver = new TestParams()
@@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite {
}
}
-object ParamsSuite extends FunSuite {
+object ParamsSuite extends SparkFunSuite {
/**
* Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index ca18fa1ad3c15..eb5408d3fee7c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.param.shared
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.Params
-class SharedParamsSuite extends FunSuite {
+class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9a35555e52b90..2e5cfe7027eb6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.scalatest.FunSuite
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils
-class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
private var tempDir: File = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 1196a772dfdd4..33aa9d0d62343 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeRegressorSuite.compareAPIs
@@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
// TODO: test("model save/load") SPARK-6725
}
-private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 40e7e3273e965..98fb3d3f5f22c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTRegressorSuite.compareAPIs
@@ -129,7 +128,7 @@ private object GBTRegressorSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = GBTRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 50a78631fa6d6..732e2c42be144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 3efffbb763b78..b24ecaa57c89b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestRegressorSuite.compareAPIs
@@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private object RandomForestRegressorSuite extends FunSuite {
+private object RandomForestRegressorSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newModel since this is not checked anyway.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 60d8bfe38fb13..5ba469c7b10a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tuning
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
index 20aa100112bfe..810b70049ec15 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.{ParamMap, TestParams}
-class ParamGridBuilderSuite extends FunSuite {
+class ParamGridBuilderSuite extends SparkFunSuite {
val solver = new TestParams()
import solver.{inputCol, maxIter}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index 3d362b5ee53ea..59944416d96a6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.api.python
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating
-class PythonMLLibAPISuite extends FunSuite {
+class PythonMLLibAPISuite extends SparkFunSuite {
SerDe.initialize()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 966811a5a3263..e8f3d0c4db20a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.util.Random
import scala.util.control.Breaks._
-import org.scalatest.FunSuite
import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -119,7 +119,7 @@ object LogisticRegressionSuite {
}
// Preventing the overflow when we compute the probability
val maxMargin = margins.max
- if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin
+ if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
// Computing the probabilities for each class from the margins.
val norm = {
@@ -130,7 +130,7 @@ object LogisticRegressionSuite {
}
temp
}
- for (i <-0 until nClasses) probs(i) /= norm
+ for (i <- 0 until nClasses) probs(i) /= norm
// Compute the cumulative probability so we can generate a random number and assign a label.
for (i <- 1 until nClasses) probs(i) += probs(i - 1)
@@ -169,7 +169,7 @@ object LogisticRegressionSuite {
}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
input: Seq[LabeledPoint],
@@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}
-class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction using SGD optimizer") {
val m = 4
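
Aside from the suite rename, the hunks above also fix spacing in the data generator's probability loops. The `maxMargin` subtraction there is the standard overflow guard: shifting every margin by the same constant leaves the normalized probabilities unchanged (the common factor cancels) but keeps `math.exp` in range. A standalone sketch of the same idea:

```scala
// Numerically stable softmax: subtracting the maximum margin before
// exponentiating prevents overflow for large margins without changing
// the resulting probabilities.
def stableSoftmax(margins: Array[Double]): Array[Double] = {
  val maxMargin = margins.max
  val exps = margins.map(m => math.exp(m - maxMargin))
  val norm = exps.sum
  exps.map(_ / norm)
}
```
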
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index ea40b41bbbe5e..f7fc8730606af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,9 +21,8 @@ import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -86,7 +85,7 @@ object NaiveBayesSuite {
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
}
-class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
import NaiveBayes.{Multinomial, Bernoulli}
@@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 90f9cec6855bf..b1d78cba9e3dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -21,9 +21,8 @@ import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -62,7 +61,7 @@ object SVMSuite {
}
-class SVMSuite extends FunSuite with MLlibTestSparkContext {
+class SVMSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 5683b55e8500a..e98b61e13e21f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index f356ffa3e3a26..b218d72f1268a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
+class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
test("single cluster") {
val data = sc.parallelize(Array(
Vectors.dense(6.0, 9.0),
@@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
}
-
+
test("two clusters") {
val data = sc.parallelize(GaussianTestData.data)
@@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
-
+
val gmm = new GaussianMixture()
.setK(2)
.setInitialModel(initialGmm)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 877e6dc699523..0dbbd7127444f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class KMeansSuite extends FunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
@@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object KMeansSuite extends FunSuite {
+object KMeansSuite extends SparkFunSuite {
def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
val singlePoint = isSparse match {
case true =>
@@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite {
}
}
-class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index d5b7d96335744..406affa25539d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class LDASuite extends FunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
import LDASuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 556842f3129a3..19e65f1b53ab5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
+class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.PowerIterationClustering._
@@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
-
+
val model2 = new PowerIterationClustering()
.setK(2)
.setInitializationMode("degree")
@@ -130,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
}
}
-object PowerIterationClusteringSuite extends FunSuite {
+object PowerIterationClusteringSuite extends SparkFunSuite {
def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
val assignments = sc.parallelize(
(0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index f90025d535e45..ac01622b8a089 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
-class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
@@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
assert(math.abs(c1) ~== 0.8 absTol 0.6)
}
+ test("SPARK-7946 setDecayFactor") {
+ val kMeans = new StreamingKMeans()
+ assert(kMeans.decayFactor === 1.0)
+ kMeans.setDecayFactor(2.0)
+ assert(kMeans.decayFactor === 2.0)
+ }
+
def StreamingKMeansDataGenerator(
numPoints: Int,
numBatches: Int,
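
The SPARK-7946 regression test above only exercises the new `decayFactor` getter/setter, but the parameter itself controls how quickly streaming k-means forgets old data: on each batch, the old cluster weight is discounted by the decay factor before being merged with the new batch. A minimal sketch of that weighted update for a single one-dimensional center, as an illustration rather than the Spark implementation:

```scala
// One decayed center update: decayFactor = 1.0 reduces to a plain
// running mean, while decayFactor = 0.0 keeps only the newest batch.
def decayedUpdate(
    oldCenter: Double, oldWeight: Double,
    batchMean: Double, batchWeight: Double,
    decayFactor: Double): (Double, Double) = {
  val discounted = oldWeight * decayFactor
  val newWeight = discounted + batchWeight
  val newCenter = (oldCenter * discounted + batchMean * batchWeight) / newWeight
  (newCenter, newWeight)
}
```
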
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
index 79847633ff0dc..87ccc7eda44ea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
+class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index e0224f960cc43..99d52fabc5309 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index 2537dd62c92f2..f3b19aeb42f84 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multilabel evaluation metrics") {
/*
* Documents true labels (5x class0, 3x class1, 4x class2):
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 609eed983ff4e..c0924a213a844 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Ranking metrics: map, ndcg") {
val predictionAndLabels = sc.parallelize(
Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 3aa732474ec2e..9de2bdb6d7246 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("regression metrics") {
val predictionAndObservations = sc.parallelize(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 747f5914598ec..889727fb55823 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
/*
* Contingency tables
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
index f3a482abda873..ccbf8a91cdd37 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext {
+class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {
test("elementwise (hadamard) product should properly apply vector to dense data set") {
val denseData = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index 0c4dfb7b97c7f..cf279c02334e9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("hashing tf on a single doc") {
val hashingTF = new HashingTF(1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 0a5cad7caf8e4..21163633051e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("idf") {
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
index 5c4af2b99e68b..34122d6ed2e95 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
import breeze.linalg.{norm => brzNorm}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index 758af588f1c69..e57f49191378f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class PCASuite extends FunSuite with MLlibTestSparkContext {
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
private val data = Array(
Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 1eb991869de40..6ab2fa6770123 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
import org.apache.spark.rdd.RDD
-class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
// When the input data is all constant, the variance is zero. Standardization against
// zero variance is not well-defined, but we choose to set the result to zero here.
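The comment preserved above documents a real edge case: standardization divides each centered entry by its column's standard deviation,

$$x'_{ij} = \frac{x_{ij} - \mu_j}{\sigma_j},$$

which is undefined for a constant column ($\sigma_j = 0$), so the implementation maps those entries to zero instead.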
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 98a98a7599bcb..b6818369208d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
// TODO: add more tests
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index bd5b9cc3afa10..66ae3543ecc4e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -16,11 +16,10 @@
*/
package org.apache.spark.mllib.fpm
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
+class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
test("FP-Growth using String type") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
index 04017f67c311d..a56d7b3579213 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm
import scala.language.existentials
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPTreeSuite extends FunSuite with MLlibTestSparkContext {
+class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
test("add transaction") {
val tree = new FPTree[String]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index 699f009f0f2ec..d34888af2d73b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -17,18 +17,16 @@
package org.apache.spark.mllib.impl
-import org.scalatest.FunSuite
-
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
import PeriodicGraphCheckpointerSuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 64ecd12ea7ded..b0f3f71113c57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.linalg.BLAS._
-class BLASSuite extends FunSuite {
+class BLASSuite extends SparkFunSuite {
test("copy") {
val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
@@ -140,7 +139,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dA)
assert(dA ~== expected absTol 1e-15)
-
+
val dB =
new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
@@ -149,7 +148,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dB)
}
}
-
+
val dC =
new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
@@ -158,7 +157,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dC)
}
}
-
+
val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
withClue("Size of vector must match the rank of matrix") {
@@ -256,13 +255,13 @@ class BLASSuite extends FunSuite {
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
-
+
val dA2 =
new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
val sA2 =
new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
true)
-
+
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
@@ -271,7 +270,7 @@ class BLASSuite extends FunSuite {
assert(sA.multiply(dx) ~== expected absTol 1e-15)
assert(dA.multiply(sx) ~== expected absTol 1e-15)
assert(sA.multiply(sx) ~== expected absTol 1e-15)
-
+
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
@@ -288,7 +287,7 @@ class BLASSuite extends FunSuite {
val y14 = y1.copy
val y15 = y1.copy
val y16 = y1.copy
-
+
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
@@ -296,42 +295,42 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA, dx, 2.0, y2)
gemv(1.0, dA, sx, 2.0, y3)
gemv(1.0, sA, sx, 2.0, y4)
-
+
gemv(1.0, dA2, dx, 2.0, y5)
gemv(1.0, sA2, dx, 2.0, y6)
gemv(1.0, dA2, sx, 2.0, y7)
gemv(1.0, sA2, sx, 2.0, y8)
-
+
gemv(2.0, dA, dx, 2.0, y9)
gemv(2.0, sA, dx, 2.0, y10)
gemv(2.0, dA, sx, 2.0, y11)
gemv(2.0, sA, sx, 2.0, y12)
-
+
gemv(2.0, dA2, dx, 2.0, y13)
gemv(2.0, sA2, dx, 2.0, y14)
gemv(2.0, dA2, sx, 2.0, y15)
gemv(2.0, sA2, sx, 2.0, y16)
-
+
assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
assert(y3 ~== expected2 absTol 1e-15)
assert(y4 ~== expected2 absTol 1e-15)
-
+
assert(y5 ~== expected2 absTol 1e-15)
assert(y6 ~== expected2 absTol 1e-15)
assert(y7 ~== expected2 absTol 1e-15)
assert(y8 ~== expected2 absTol 1e-15)
-
+
assert(y9 ~== expected3 absTol 1e-15)
assert(y10 ~== expected3 absTol 1e-15)
assert(y11 ~== expected3 absTol 1e-15)
assert(y12 ~== expected3 absTol 1e-15)
-
+
assert(y13 ~== expected3 absTol 1e-15)
assert(y14 ~== expected3 absTol 1e-15)
assert(y15 ~== expected3 absTol 1e-15)
assert(y16 ~== expected3 absTol 1e-15)
-
+
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
gemv(1.0, dA.transpose, dx, 2.0, y1)
@@ -346,12 +345,12 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA.transpose, sx, 2.0, y1)
}
}
-
+
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
-
+
val dATT = dAT.transpose
val sATT = sAT.transpose
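All sixteen `gemv` calls above exercise the same BLAS-2 update, $y \leftarrow \alpha A x + \beta y$, so the expected vectors follow directly from the operands defined in this hunk: with $Ax = (4, 1, 2, 9)^\top$ and the initial $y = (1, 3, 1, 0)^\top$,

$$\alpha = 1,\ \beta = 2:\quad (4,1,2,9)^\top + 2\,(1,3,1,0)^\top = (6,7,4,9)^\top = \texttt{expected2}$$
$$\alpha = 2,\ \beta = 2:\quad 2\,(4,1,2,9)^\top + 2\,(1,3,1,0)^\top = (10,8,6,18)^\top = \texttt{expected3}$$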
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
index 2031032373971..dc04258e41d27 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}
-class BreezeMatrixConversionSuite extends FunSuite {
+import org.apache.spark.SparkFunSuite
+
+class BreezeMatrixConversionSuite extends SparkFunSuite {
test("dense matrix to breeze") {
val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
val breeze = mat.toBreeze.asInstanceOf[BDM[Double]]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
index 8abdac72902c6..3772c9235ad3a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import org.apache.spark.SparkFunSuite
+
/**
* Test Breeze vector conversions.
*/
-class BreezeVectorConversionSuite extends FunSuite {
+class BreezeVectorConversionSuite extends SparkFunSuite {
val arr = Array(0.1, 0.2, 0.3, 0.4)
val n = 20
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 86119ec38101e..8dbb70f5d1c4c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg
import java.util.Random
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class MatricesSuite extends FunSuite {
+class MatricesSuite extends SparkFunSuite {
test("dense matrix construction") {
val m = 3
val n = 2
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 24755e9ff46fc..c4ae0a16f7c04 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg
import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.util.TestingUtils._
-class VectorsSuite extends FunSuite {
+class VectorsSuite extends SparkFunSuite {
val arr = Array(0.1, 0.0, 0.3, 0.4)
val n = 4
@@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite {
val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
- // SparseVector vs. SparseVector
- assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+ // SparseVector vs. SparseVector
+ assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. SparseVector
assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. DenseVector
assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
- }
+ }
}
test("foreachActive") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index a58336175899c..93fe04c139b9a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed
import java.{util => ju}
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index 04b36a9ef9990..f3728cd036a3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.Vectors
-class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 2ab53cc13db71..4a7b99a976f0a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
-class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 27bb19f472e1e..b6cb53d0c743e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed
import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
-class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
@@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
var mat: RowMatrix = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index e110506d579b0..a5a59e9fad5ae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization
import scala.collection.JavaConversions._
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -61,7 +62,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
@@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc
}
}
-class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index c8f2adcf155a7..d07b9d5b89227 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
val nPoints = 10000
val A = 2.0
@@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
}
}
-class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index bb723fc471181..d8f9b8c33963d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.FunSuite
-
import org.jblas.{DoubleMatrix, SimpleBlas}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class NNLSSuite extends FunSuite {
+class NNLSSuite extends SparkFunSuite {
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
index 0b646cf1ce6c4..4c6e76e47419b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.util.LinearDataGenerator
-class BinaryClassificationPMMLModelExportSuite extends FunSuite {
+class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
test("logistic regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
@@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure logistic regression has normalization method set to LOGIT
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
}
-
+
test("linear SVM PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
-
+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
-
+
// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
@@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure linear SVM has normalization method set to NONE
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
index f9afbd888dfc5..1d32309481787 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
+class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite {
test("linear regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
index b985d0446d7b0..b3f9750afa730 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.ClusteringModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
-class KMeansPMMLModelExportSuite extends FunSuite {
+class KMeansPMMLModelExportSuite extends SparkFunSuite {
test("KMeansPMMLModelExport generate PMML format") {
val clusterCenters = Array(
@@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
index f28a4ac8ad01f..af49450961750 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.pmml.export
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class PMMLModelExportFactorySuite extends FunSuite {
+class PMMLModelExportFactorySuite extends SparkFunSuite {
test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
val clusterCenters = Array(
@@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite {
test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+ "when passing a LogisticRegressionModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
-
+
val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
val logisticRegressionModelExport =
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
-
+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
}
-
+
test("PMMLModelExportFactory throw IllegalArgumentException "
+ "when passing a Multinomial Logistic Regression") {
/** 3 classes, 2 features */
val multiclassLogisticRegressionModel = new LogisticRegressionModel(
- weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
numFeatures = 2, numClasses = 3)
-
+
intercept[IllegalArgumentException] {
PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index b792d819fdabb..a5ca1518f82f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.random
import scala.math
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.util.StatCounter
// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
-class RandomDataGeneratorSuite extends FunSuite {
+class RandomDataGeneratorSuite extends SparkFunSuite {
def apiChecks(gen: RandomDataGenerator[Double]) {
// resetting seed should generate the same sequence of random numbers
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index 63f2ea916d457..413db2000d6d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.random
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
@@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter
*
* TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
*/
-class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
+class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
def testGeneratedRDD(rdd: RDD[Double],
expectedSize: Long,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
index 57216e8eb4a55..10f5a2be48f7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
-class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("topByKey") {
val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5,
1), (3, 5)), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 6d6c0aa5be812..bc64172614830 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._
-class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("sliding") {
val data = 0 until 6
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index b3798940ddc38..05b87728d6fdb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.math.abs
import scala.util.Random
-import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
@@ -84,7 +84,7 @@ object ALSSuite {
}
-class ALSSuite extends FunSuite with MLlibTestSparkContext {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
test("rank-1 matrices") {
testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index 2c92866f3893d..2c8ed057a516a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.recommendation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
-class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
+class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {
val rank = 2
var userFeatures: RDD[(Int, Array[Double])] = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
index 3b38bdf5ef5eb..ea4f2865757c1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.{Matchers, FunSuite}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
private def round(d: Double) = {
math.round(d * 100).toDouble / 100
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
index 110c44a7193fd..d8364a06de4da 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-class LabeledPointSuite extends FunSuite {
+class LabeledPointSuite extends SparkFunSuite {
test("parse labeled points") {
val points = Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 71dce50922991..08a152ffc7a23 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LassoSuite {
val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LassoSuite extends FunSuite with MLlibTestSparkContext {
+class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -143,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 3781931c2f819..f88a1c33c9f7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LinearRegressionSuite {
val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index d6c93cc0e49cd..7a781fee634c8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -33,7 +33,7 @@ private object RidgeRegressionSuite {
val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
predictions.zip(input).map { case (prediction, expected) =>
@@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 26604dbe6c1ef..9a379406d5061 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index a7e6fce31ff7e..c292ced75e870 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
SpearmanCorrelation}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
// test input data
val xData = Array(1.0, 0.0, -2.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 15418e6035965..b084a5fb4313f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat
import java.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
+class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("chi squared pearson goodness of fit") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
index a309c942cf8ff..5feccdf33681a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.mllib.stat
import org.apache.commons.math3.distribution.NormalDistribution
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
+class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
test("kernel density single sample") {
val rdd = sc.parallelize(Array(5.0))
val evaluationPoints = Array(5.0, 6.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 23b0eec865de6..07efde4f5e6dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateOnlineSummarizerSuite extends FunSuite {
+class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
test("basic error handing") {
val summarizer = new MultivariateOnlineSummarizer
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index fac2498e4dcb3..aa60deb665aeb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -17,49 +17,48 @@
package org.apache.spark.mllib.stat.distribution
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
+class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
test("univariate") {
val x1 = Vectors.dense(0.0)
val x2 = Vectors.dense(1.5)
-
+
val mu = Vectors.dense(0.0)
val sigma1 = Matrices.dense(1, 1, Array(1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(1, 1, Array(4.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
}
-
+
test("multivariate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
}
-
+
test("multivariate degenerate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
val dist = new MultivariateGaussian(mu, sigma)
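The constants asserted in the univariate test come straight from the normal density

$$f(x \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-(x-\mu)^2 / (2\sigma^2)}:$$

$$f(0 \mid 0, 1) = \tfrac{1}{\sqrt{2\pi}} \approx 0.39894, \qquad f(1.5 \mid 0, 1) = \tfrac{e^{-1.125}}{\sqrt{2\pi}} \approx 0.12952,$$
$$f(0 \mid 0, 4) = \tfrac{1}{2\sqrt{2\pi}} \approx 0.19947, \qquad f(1.5 \mid 0, 4) = \tfrac{e^{-0.28125}}{2\sqrt{2\pi}} \approx 0.15057.$$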
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index ce983eb27fa35..356d957f15909 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
-class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
// Tests examining individual elements of training
@@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object DecisionTreeSuite extends FunSuite {
+object DecisionTreeSuite extends SparkFunSuite {
def validateClassifier(
model: DecisionTreeModel,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 55b0bac7d49fe..84dd3b342d4c0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
@@ -32,7 +31,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GradientBoostedTrees]].
*/
-class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 92b498580af03..49aff21fe7914 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
-class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 4ed66953cb628..e6df5d974bf36 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -35,7 +34,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[RandomForest]].
*/
-class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index b184e936672ca..9d756da410325 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree.impl
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suite for [[BaggedPoint]].
*/
-class BaggedPointSuite extends FunSuite with MLlibTestSparkContext {
+class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
test("BaggedPoint RDD: without subsampling") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 27050bde16ef3..7bb9d570a0b89 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -64,7 +64,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
val fastSquaredDist3 =
fastSquaredDistance(v2, norm2, v3, norm3, precision)
assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m")
- if (m > 10) {
+ if (m > 10) {
val v4 = Vectors.sparse(n, indices.slice(0, m - 10),
indices.map(i => a(i) + 0.5).slice(0, m - 10))
val norm4 = Vectors.norm(v4, 2.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
index f68fb95eac4e4..8dcb9ba9be108 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.util
-import org.scalatest.FunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.SparkException
-
-class NumericParserSuite extends FunSuite {
+class NumericParserSuite extends SparkFunSuite {
test("parser") {
val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
index 59e6c778806f4..8f475f30249d6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.util
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.scalatest.FunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.scalatest.exceptions.TestFailedException
-class TestingUtilsSuite extends FunSuite {
+class TestingUtilsSuite extends SparkFunSuite {
test("Comparing doubles using relative error.") {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 11b439e7875fc..8da72b3fa7cdb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,6 +38,8 @@ object MimaExcludes {
Seq(
MimaBuild.excludeSparkPackage("deploy"),
MimaBuild.excludeSparkPackage("ml"),
+ // SPARK-7910 Adding a method to get the partitioner to JavaRDD.
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
// SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
// These are needed if checking against the sbt build, since they are part of
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b9515a12bc573..9a849639233bc 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConversions._
import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
-import sbtunidoc.Plugin.genjavadocSettings
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
@@ -118,7 +117,12 @@ object SparkBuild extends PomBuild {
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq (
+ lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq(
+ libraryDependencies += compilerPlugin(
+ "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full),
+ scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java")))
+
+ lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq (
javaHome := sys.env.get("JAVA_HOME")
.orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() })
.map(file),
@@ -126,7 +130,7 @@ object SparkBuild extends PomBuild {
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
publishMavenStyle := true,
- unidocGenjavadocVersion := "0.8",
+ unidocGenjavadocVersion := "0.9-spark0",
resolvers += Resolver.mavenLocal,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 7096b0d3ee7de..75bd604a1b857 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
-addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1")
+addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3")
addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 0d21a132048a5..adca90ddaf397 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -261,3 +261,7 @@ def _start_update_server():
thread.daemon = True
thread.start()
return server
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index b0479d9b074db..ddb33f427ac64 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -324,65 +324,73 @@ def getP(self):
@inherit_doc
class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
"""
- A one-hot encoder that maps a column of label indices to a column of binary vectors, with
- at most a single one-value. By default, the binary vector has an element for each category, so
- with 5 categories, an input value of 2.0 would map to an output vector of
- (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so
- the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
- of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
- linearly dependent because they sum up to one.
-
- TODO: This method requires the use of StringIndexer first. Decouple them.
+ A one-hot encoder that maps a column of category indices to a
+ column of binary vectors, with at most a single one-value per row
+ that indicates the input category index.
+ For example with 5 categories, an input value of 2.0 would map to
+ an output vector of `[0.0, 0.0, 1.0, 0.0]`.
+ The last category is not included by default (configurable via
+ :py:attr:`dropLast`) because it makes the vector entries sum up to
+ one, and hence linearly dependent.
+ So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ Note that this is different from scikit-learn's OneHotEncoder,
+ which keeps all categories.
+ The output vectors are sparse.
+
+ .. seealso::
+
+ :py:class:`StringIndexer` for converting categorical values into
+ category indices
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
- >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features")
+ >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features")
>>> encoder.transform(td).head().features
- SparseVector(2, {})
+ SparseVector(2, {0: 1.0})
>>> encoder.setParams(outputCol="freqs").transform(td).head().freqs
- SparseVector(2, {})
- >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"}
+ SparseVector(2, {0: 1.0})
+ >>> params = {encoder.dropLast: False, encoder.outputCol: "test"}
>>> encoder.transform(td, params).head().test
SparseVector(3, {0: 1.0})
"""
# a placeholder to make it appear in the generated doc
- includeFirst = Param(Params._dummy(), "includeFirst", "include first category")
+ dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category")
@keyword_only
- def __init__(self, includeFirst=True, inputCol=None, outputCol=None):
+ def __init__(self, dropLast=True, inputCol=None, outputCol=None):
"""
-        __init__(self, includeFirst=True, inputCol=None, outputCol=None)
+        __init__(self, dropLast=True, inputCol=None, outputCol=None)
"""
super(OneHotEncoder, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
- self.includeFirst = Param(self, "includeFirst", "include first category")
- self._setDefault(includeFirst=True)
+ self.dropLast = Param(self, "dropLast", "whether to drop the last category")
+ self._setDefault(dropLast=True)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
- def setParams(self, includeFirst=True, inputCol=None, outputCol=None):
+ def setParams(self, dropLast=True, inputCol=None, outputCol=None):
"""
- setParams(self, includeFirst=True, inputCol=None, outputCol=None)
+ setParams(self, dropLast=True, inputCol=None, outputCol=None)
Sets params for this OneHotEncoder.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
- def setIncludeFirst(self, value):
+ def setDropLast(self, value):
"""
- Sets the value of :py:attr:`includeFirst`.
+ Sets the value of :py:attr:`dropLast`.
"""
- self._paramMap[self.includeFirst] = value
+ self._paramMap[self.dropLast] = value
return self
- def getIncludeFirst(self):
+ def getDropLast(self):
"""
- Gets the value of includeFirst or its default value.
+ Gets the value of dropLast or its default value.
"""
- return self.getOrDefault(self.includeFirst)
+ return self.getOrDefault(self.dropLast)
@inherit_doc
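The dropLast mapping documented above is easy to verify outside Spark. Below is a minimal plain-Python sketch of the behavior; the helper name one_hot and its return convention (vector size plus an index-to-value dict, echoing SparseVector) are illustrative only, not part of the API.

    def one_hot(index, num_categories, drop_last=True):
        # drop_last=True removes the last category, so that category
        # encodes as the all-zeros vector, as the docstring describes.
        size = num_categories - 1 if drop_last else num_categories
        if index < size:
            return size, {index: 1.0}
        return size, {}

    # With 5 categories: 2.0 -> [0.0, 0.0, 1.0, 0.0]; 4.0 -> all zeros.
    assert one_hot(2, 5) == (4, {2: 1.0})
    assert one_hot(4, 5) == (4, {})
    assert one_hot(4, 5, drop_last=False) == (5, {4: 1.0})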
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 497841b6c8ce6..0bf988fd72f14 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
- ... [(Vectors.dense([0.0, 1.0]), 0.0),
- ... (Vectors.dense([1.0, 2.0]), 1.0),
- ... (Vectors.dense([0.55, 3.0]), 0.0),
- ... (Vectors.dense([0.45, 4.0]), 1.0),
- ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... [(Vectors.dense([0.0]), 0.0),
+ ... (Vectors.dense([0.4]), 1.0),
+ ... (Vectors.dense([0.5]), 0.0),
+ ... (Vectors.dense([0.6]), 1.0),
+ ... (Vectors.dense([1.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- >>> # SPARK-7432: The following test is flaky.
- >>> # cvModel = cv.fit(dataset)
- >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
- >>> # cvModel.transform(dataset).collect() == expected.collect()
+ >>> cvModel = cv.fit(dataset)
+ >>> evaluator.evaluate(cvModel.transform(dataset))
+ 0.8333...
"""
# a placeholder to make it appear in the generated doc
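For readers following the restored doctest above, cv.fit hides a standard k-fold loop: for each parameter map, train on the other folds, score the held-out fold with the evaluator, and keep the map with the best average metric. A plain-Python sketch of that loop, with fit_fn and eval_fn as hypothetical stand-ins for the estimator and evaluator:

    def cross_validate(data, param_maps, fit_fn, eval_fn, num_folds=3):
        fold_size = len(data) // num_folds
        avg_scores = []
        for params in param_maps:
            total = 0.0
            for k in range(num_folds):
                held_out = data[k * fold_size:(k + 1) * fold_size]
                training = data[:k * fold_size] + data[(k + 1) * fold_size:]
                total += eval_fn(fit_fn(training, params), held_out)
            avg_scores.append(total / num_folds)
        # Keep the parameter map with the best average held-out metric.
        best = max(range(len(param_maps)), key=avg_scores.__getitem__)
        return param_maps[best]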
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 07507b2ad0d05..b11aed2c3afda 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -28,11 +28,3 @@
__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
'recommendation', 'regression', 'stat', 'tree', 'util']
-
-import sys
-from . import rand as random
-modname = __name__ + '.random'
-random.__name__ = modname
-random.RandomRDDs.__module__ = modname
-sys.modules[modname] = random
-del modname, sys
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index aab5e5f4b77b5..c5cf3a4e7ff22 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
Evaluator for binary classification.
+ :param scoreAndLabels: an RDD of (score, label) pairs
+
>>> scoreAndLabels = sc.parallelize([
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
>>> metrics = BinaryClassificationMetrics(scoreAndLabels)
@@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
def __init__(self, scoreAndLabels):
- """
- :param scoreAndLabels: an RDD of (score, label) pairs
- """
sc = scoreAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
@@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper):
"""
Evaluator for regression.
+ :param predictionAndObservations: an RDD of (prediction,
+ observation) pairs.
+
>>> predictionAndObservations = sc.parallelize([
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
>>> metrics = RegressionMetrics(predictionAndObservations)
@@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndObservations):
- """
- :param predictionAndObservations: an RDD of (prediction, observation) pairs.
- """
sc = predictionAndObservations.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
@@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
+ :param predictionAndLabels: an RDD of (prediction, label) pairs.
+
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
@@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels an RDD of (prediction, label) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
@@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper):
"""
Evaluator for ranking algorithms.
+ :param predictionAndLabels: an RDD of (predicted ranking,
+ ground truth set) pairs.
+
>>> predictionAndLabels = sc.parallelize([
... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
@@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels,
@@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper):
"""
Evaluator for multilabel classification.
+ :param predictionAndLabels: an RDD of (predictions, labels) pairs,
+     where both are non-null arrays, each with unique elements.
+
>>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index aac305db6c19a..da90554f41437 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -68,6 +68,8 @@ class Normalizer(VectorTransformer):
For `p` = float('inf'), max(abs(vector)) will be used as norm for
normalization.
+ :param p: Normalization in L^p^ space, p = 2 by default.
+
>>> v = Vectors.dense(range(3))
>>> nor = Normalizer(1)
>>> nor.transform(v)
@@ -82,9 +84,6 @@ class Normalizer(VectorTransformer):
DenseVector([0.0, 0.5, 1.0])
"""
def __init__(self, p=2.0):
- """
- :param p: Normalization in L^p^ space, p = 2 by default.
- """
assert p >= 1.0, "p should be greater than 1.0"
self.p = float(p)
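As a quick plain-Python illustration of the L^p normalization documented above (a sketch, not the MLlib code path):

    def normalize(values, p=2.0):
        # p == inf selects the max norm, per the docstring above.
        if p == float('inf'):
            norm = max(abs(v) for v in values)
        else:
            norm = sum(abs(v) ** p for v in values) ** (1.0 / p)
        # A zero-norm input is returned unchanged, matching transform().
        return values if norm == 0 else [v / norm for v in values]

    assert normalize([0.0, 1.0, 2.0], p=1.0) == [0.0, 1.0 / 3.0, 2.0 / 3.0]
    assert normalize([0.0, 1.0, 2.0], p=float('inf')) == [0.0, 0.5, 1.0]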
@@ -94,7 +93,7 @@ def transform(self, vector):
:param vector: vector or RDD of vector to be normalized.
:return: normalized vector. If the norm of the input is zero, it
- will return the input vector.
+ will return the input vector.
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
@@ -164,6 +163,13 @@ class StandardScaler(object):
variance using column summary statistics on the samples in the
training set.
+ :param withMean: False by default. Centers the data with mean
+ before scaling. It will build a dense output, so this
+ does not work on sparse input and will raise an
+ exception.
+ :param withStd: True by default. Scales the data to unit
+ standard deviation.
+
>>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
>>> dataset = sc.parallelize(vs)
>>> standardizer = StandardScaler(True, True)
@@ -174,14 +180,6 @@ class StandardScaler(object):
DenseVector([0.7071, -0.7071, 0.7071])
"""
def __init__(self, withMean=False, withStd=True):
- """
- :param withMean: False by default. Centers the data with mean
- before scaling. It will build a dense output, so this
- does not work on sparse input and will raise an
- exception.
- :param withStd: True by default. Scales the data to unit
- standard deviation.
- """
if not (withMean or withStd):
warnings.warn("Both withMean and withStd are false. The model does nothing.")
self.withMean = withMean
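What fitting and applying such a scaler computes can be sketched column-wise in plain Python. The sketch assumes at least two rows and uses the sample standard deviation; the exact estimator MLlib applies is an assumption here.

    def fit_standard_scaler(rows, with_mean=False, with_std=True):
        n = float(len(rows))
        cols = list(zip(*rows))
        means = [sum(c) / n for c in cols]
        stds = [(sum((x - m) ** 2 for x in c) / (n - 1)) ** 0.5
                for c, m in zip(cols, means)]

        def transform(row):
            out = []
            for x, m, s in zip(row, means, stds):
                if with_mean:
                    x -= m      # centering densifies sparse input
                if with_std and s != 0:
                    x /= s
                out.append(x)
            return out

        return transform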
@@ -193,7 +191,7 @@ def fit(self, dataset):
for later scaling.
:param data: The data used to compute the mean and variance
- to build the transformation model.
+ to build the transformation model.
:return: a StandardScalarModel
"""
dataset = dataset.map(_convert_to_vector)
@@ -223,6 +221,8 @@ class ChiSqSelector(object):
Creates a ChiSquared feature selector.
+ :param numTopFeatures: number of features that selector will select.
+
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
@@ -236,9 +236,6 @@ class ChiSqSelector(object):
DenseVector([5.0])
"""
def __init__(self, numTopFeatures):
- """
- :param numTopFeatures: number of features that selector will select.
- """
self.numTopFeatures = int(numTopFeatures)
def fit(self, data):
@@ -246,9 +243,9 @@ def fit(self, data):
Returns a ChiSquared feature selector.
:param data: an `RDD[LabeledPoint]` containing the labeled dataset
- with categorical features. Real-valued features will be
- treated as categorical for each distinct value.
- Apply feature discretizer before using this function.
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ Apply feature discretizer before using this function.
"""
jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
return ChiSqSelectorModel(jmodel)
@@ -263,15 +260,14 @@ class HashingTF(object):
Note: the terms must be hashable (can not be dict/set/list...).
+ :param numFeatures: number of features (default: 2^20)
+
>>> htf = HashingTF(100)
>>> doc = "a a b b c d".split(" ")
>>> htf.transform(doc)
SparseVector(100, {...})
"""
def __init__(self, numFeatures=1 << 20):
- """
- :param numFeatures: number of features (default: 2^20)
- """
self.numFeatures = numFeatures
def indexOf(self, term):
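The hashing trick behind HashingTF fits in a few lines of plain Python; this sketch uses Python's builtin hash in place of the hash function MLlib actually applies, so the exact indices will differ from the doctest above.

    def hashing_tf(terms, num_features=1 << 20):
        freq = {}
        for term in terms:
            index = hash(term) % num_features  # what indexOf(term) computes
            freq[index] = freq.get(index, 0.0) + 1.0
        return num_features, freq             # size plus sparse entries

    size, entries = hashing_tf("a a b b c d".split(" "), num_features=100)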
@@ -311,7 +307,7 @@ def transform(self, x):
Call transform directly on the RDD instead.
:param x: an RDD of term frequency vectors or a term frequency
- vector
+ vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
if isinstance(x, RDD):
@@ -342,6 +338,9 @@ class IDF(object):
`minDocFreq`). For terms that are not in at least `minDocFreq`
documents, the IDF is found as 0, resulting in TF-IDFs of 0.
+ :param minDocFreq: minimum number of documents in which a term
+     should appear for filtering
+
>>> n = 4
>>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
... Vectors.dense([0.0, 1.0, 2.0, 3.0]),
@@ -362,10 +361,6 @@ class IDF(object):
SparseVector(4, {1: 0.0, 3: 0.5754})
"""
def __init__(self, minDocFreq=0):
- """
- :param minDocFreq: minimum of documents in which a term
- should appear for filtering
- """
self.minDocFreq = minDocFreq
def fit(self, dataset):
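Concretely, MLlib computes the weight as idf = log((m + 1) / (d(t) + 1)), where m is the number of documents and d(t) the number of documents containing term t, and uses 0 for terms below minDocFreq. A small plain-Python sketch of that formula:

    import math

    def idf_weights(doc_freqs, num_docs, min_doc_freq=0):
        # doc_freqs maps each term to the number of documents containing it.
        return dict(
            (term,
             math.log((num_docs + 1.0) / (df + 1.0)) if df >= min_doc_freq
             else 0.0)
            for term, df in doc_freqs.items())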
diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py
similarity index 100%
rename from python/pyspark/mllib/rand.py
rename to python/pyspark/mllib/random.py
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 8fee92ae3aed5..726d288d97b2e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -50,18 +50,6 @@ def deco(f):
return f
return deco
-# fix the module name conflict for Python 3+
-import sys
-from . import _types as types
-modname = __name__ + '.types'
-types.__name__ = modname
-# update the __module__ for all objects, make them picklable
-for v in types.__dict__.values():
- if hasattr(v, "__module__") and v.__module__.endswith('._types'):
- v.__module__ = modname
-sys.modules[modname] = types
-del modname, sys
-
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.column import Column
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 8dc5039f587f0..1ecec5b126505 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound):
"""
A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
+
+ >>> df.select(df.name, df.age.between(2, 4)).show()
+ +-----+--------------------------+
+ | name|((age >= 2) && (age <= 4))|
+ +-----+--------------------------+
+ |Alice| true|
+ | Bob| false|
+ +-----+--------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)
@@ -328,12 +336,20 @@ def when(self, condition, value):
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+ +-----+--------------------------------------------------------+
+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
+ +-----+--------------------------------------------------------+
+ |Alice| -1|
+ | Bob| 1|
+ +-----+--------------------------------------------------------+
"""
- sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
- jc = sc._jvm.functions.when(condition._jc, v)
+ jc = self._jc.when(condition._jc, v)
return Column(jc)
@since(1.4)
@@ -345,9 +361,18 @@ def otherwise(self, value):
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+ +-----+---------------------------------+
+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
+ +-----+---------------------------------+
+ |Alice| 0|
+ | Bob| 1|
+ +-----+---------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
- jc = self._jc.otherwise(value)
+ jc = self._jc.otherwise(v)
return Column(jc)
@since(1.4)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 936487519a645..a82b6b87c413e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1170,6 +1170,9 @@ def freqItems(self, cols, support=None):
"http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
:func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
+ This function is meant for exploratory data analysis, as we make no guarantee about the
+ backward compatibility of the schema of the resulting DataFrame.
+
:param cols: Names of the columns to calculate frequent items for as a list or tuple of
strings.
:param support: The frequency with which to consider an item 'frequent'. Default is 1%.
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413bec7db..d17d87419fe3d 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -43,6 +43,39 @@ def _df(self, jdf):
from pyspark.sql.dataframe import DataFrame
return DataFrame(jdf, self._sqlContext)
+ @since(1.4)
+ def format(self, source):
+ """
+ Specifies the input data source format.
+ """
+ self._jreader = self._jreader.format(source)
+ return self
+
+ @since(1.4)
+ def schema(self, schema):
+ """
+ Specifies the input schema. Some data sources (e.g. JSON) can
+ infer the input schema automatically from data. By specifying
+ the schema here, the underlying data source can skip the schema
+ inference step, and thus speed up data loading.
+
+ :param schema: a StructType object
+ """
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ self._jreader = self._jreader.schema(jschema)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """
+ Adds input options for the underlying data source.
+ """
+ for k in options:
+ self._jreader = self._jreader.option(k, options[k])
+ return self
+
@since(1.4)
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
@@ -52,20 +85,15 @@ def load(self, path=None, format=None, schema=None, **options):
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
"""
- jreader = self._jreader
if format is not None:
- jreader = jreader.format(format)
+ self.format(format)
if schema is not None:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jreader = jreader.schema(jschema)
- for k in options:
- jreader = jreader.option(k, options[k])
+ self.schema(schema)
+ self.options(**options)
if path is not None:
- return self._df(jreader.load(path))
+ return self._df(self._jreader.load(path))
else:
- return self._df(jreader.load())
+ return self._df(self._jreader.load())
@since(1.4)
def json(self, path, schema=None):
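With format, schema, and options factored out as chainable setters, load becomes thin sugar over the builder, and callers can use either style. A sketch, assuming an existing SQLContext named sqlContext and the people.json sample file shipped with Spark:

    from pyspark.sql.types import StructType, StructField, StringType

    schema = StructType([StructField("name", StringType(), True)])
    df = (sqlContext.read
          .format("json")
          .schema(schema)
          .load("examples/src/main/resources/people.json"))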
@@ -105,12 +133,9 @@ def json(self, path, schema=None):
| |-- field5: array (nullable = true)
| | |-- element: integer (containsNull = true)
"""
- if schema is None:
- jdf = self._jreader.json(path)
- else:
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jdf = self._jreader.schema(jschema).json(path)
- return self._df(jdf)
+ if schema is not None:
+ self.schema(schema)
+ return self._df(self._jreader.json(path))
@since(1.4)
def table(self, tableName):
@@ -194,6 +219,51 @@ def __init__(self, df):
self._sqlContext = df.sql_ctx
self._jwrite = df._jdf.write()
+ @since(1.4)
+ def mode(self, saveMode):
+ """
+ Specifies the behavior when data or table already exists. Options include:
+
+ * `append`: Append contents of this :class:`DataFrame` to existing data.
+ * `overwrite`: Overwrite existing data.
+ * `error`: Throw an exception if data already exists.
+ * `ignore`: Silently ignore this operation if data already exists.
+ """
+ self._jwrite = self._jwrite.mode(saveMode)
+ return self
+
+ @since(1.4)
+ def format(self, source):
+ """
+ Specifies the underlying output data source. Built-in options include
+ "parquet", "json", etc.
+ """
+ self._jwrite = self._jwrite.format(source)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """
+ Adds output options for the underlying data source.
+ """
+ for k in options:
+ self._jwrite = self._jwrite.option(k, options[k])
+ return self
+
+ @since(1.4)
+ def partitionBy(self, *cols):
+ """
+ Partitions the output by the given columns on the file system.
+ If specified, the output is laid out on the file system similar
+ to Hive's partitioning scheme.
+
+ :param cols: name of columns
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0]
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ return self
+
@since(1.4)
def save(self, path=None, format=None, mode="error", **options):
"""
@@ -216,16 +286,15 @@ def save(self, path=None, format=None, mode="error", **options):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
+ self.format(format)
if path is None:
- jwrite.save()
+ self._jwrite.save()
else:
- jwrite.save(path)
+ self._jwrite.save(path)
+ @since(1.4)
def insertInto(self, tableName, overwrite=False):
"""
Inserts the content of the :class:`DataFrame` to the specified table.
@@ -256,12 +325,10 @@ def saveAsTable(self, name, format=None, mode="error", **options):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
- return jwrite.saveAsTable(name)
+ self.format(format)
+ return self._jwrite.saveAsTable(name)
@since(1.4)
def json(self, path, mode="error"):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5c53c3a8ed4f1..76384d31f1bf4 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -100,6 +100,15 @@ def test_data_type_eq(self):
lt2 = pickle.loads(pickle.dumps(LongType()))
self.assertEquals(lt, lt2)
+ # regression test for SPARK-7978
+ def test_decimal_type(self):
+ t1 = DecimalType()
+ t2 = DecimalType(10, 2)
+ self.assertTrue(t2 is not t1)
+ self.assertNotEqual(t1, t2)
+ t3 = DecimalType(8)
+ self.assertNotEqual(t2, t3)
+
class SQLTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py
similarity index 99%
rename from python/pyspark/sql/_types.py
rename to python/pyspark/sql/types.py
index 9e7e9f04bc35d..b6ec6137c9180 100644
--- a/python/pyspark/sql/_types.py
+++ b/python/pyspark/sql/types.py
@@ -97,8 +97,6 @@ class AtomicType(DataType):
"""An internal type used to represent everything that is not
null, UDTs, arrays, structs, and maps."""
- __metaclass__ = DataTypeSingleton
-
class NumericType(AtomicType):
"""Numeric data types.
@@ -109,6 +107,8 @@ class IntegralType(NumericType):
"""Integral data types.
"""
+ __metaclass__ = DataTypeSingleton
+
class FractionalType(NumericType):
"""Fractional data types.
@@ -119,26 +119,36 @@ class StringType(AtomicType):
"""String data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BinaryType(AtomicType):
"""Binary (byte array) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BooleanType(AtomicType):
"""Boolean data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DateType(AtomicType):
"""Date (datetime.date) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class TimestampType(AtomicType):
"""Timestamp (datetime.datetime) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
@@ -172,11 +182,15 @@ class DoubleType(FractionalType):
"""Double data type, representing double precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class FloatType(FractionalType):
"""Float data type, representing single precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class ByteType(IntegralType):
"""Byte data type, i.e. a signed integer in a single byte.
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 33ea8c9293d74..46cb18b2e8ef9 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -41,8 +41,8 @@
class PySparkStreamingTestCase(unittest.TestCase):
- timeout = 4 # seconds
- duration = .2
+ timeout = 10 # seconds
+ duration = .5
@classmethod
def setUpClass(cls):
@@ -379,13 +379,13 @@ def func(dstream):
class WindowFunctionTests(PySparkStreamingTestCase):
- timeout = 5
+ timeout = 15
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.window(.6, .2).count()
+ return dstream.window(1.5, .5).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -394,7 +394,7 @@ def test_count_by_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.countByWindow(.6, .2)
+ return dstream.countByWindow(1.5, .5)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -403,7 +403,7 @@ def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByWindow(1, .2)
+ return dstream.countByWindow(2.5, .5)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
@@ -412,7 +412,7 @@ def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByValueAndWindow(1, .2)
+ return dstream.countByValueAndWindow(2.5, .5)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
@@ -421,7 +421,7 @@ def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]
def func(dstream):
- return dstream.groupByKeyAndWindow(.6, .2).mapValues(list)
+ return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
diff --git a/python/run-tests b/python/run-tests
index ffde2fb24b369..17dda3eadac0c 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,54 +57,56 @@ function run_test() {
function run_core_tests() {
echo "Run core tests ..."
- run_test "pyspark/rdd.py"
- run_test "pyspark/context.py"
- run_test "pyspark/conf.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- run_test "pyspark/serializers.py"
- run_test "pyspark/profiler.py"
- run_test "pyspark/shuffle.py"
- run_test "pyspark/tests.py"
+ run_test "pyspark.rdd"
+ run_test "pyspark.context"
+ run_test "pyspark.conf"
+ run_test "pyspark.broadcast"
+ run_test "pyspark.accumulators"
+ run_test "pyspark.serializers"
+ run_test "pyspark.profiler"
+ run_test "pyspark.shuffle"
+ run_test "pyspark.tests"
}
function run_sql_tests() {
echo "Run sql tests ..."
- run_test "pyspark/sql/_types.py"
- run_test "pyspark/sql/context.py"
- run_test "pyspark/sql/column.py"
- run_test "pyspark/sql/dataframe.py"
- run_test "pyspark/sql/group.py"
- run_test "pyspark/sql/functions.py"
- run_test "pyspark/sql/tests.py"
+ run_test "pyspark.sql.types"
+ run_test "pyspark.sql.context"
+ run_test "pyspark.sql.column"
+ run_test "pyspark.sql.dataframe"
+ run_test "pyspark.sql.group"
+ run_test "pyspark.sql.functions"
+ run_test "pyspark.sql.readwriter"
+ run_test "pyspark.sql.window"
+ run_test "pyspark.sql.tests"
}
function run_mllib_tests() {
echo "Run mllib tests ..."
- run_test "pyspark/mllib/classification.py"
- run_test "pyspark/mllib/clustering.py"
- run_test "pyspark/mllib/evaluation.py"
- run_test "pyspark/mllib/feature.py"
- run_test "pyspark/mllib/fpm.py"
- run_test "pyspark/mllib/linalg.py"
- run_test "pyspark/mllib/rand.py"
- run_test "pyspark/mllib/recommendation.py"
- run_test "pyspark/mllib/regression.py"
- run_test "pyspark/mllib/stat/_statistics.py"
- run_test "pyspark/mllib/tree.py"
- run_test "pyspark/mllib/util.py"
- run_test "pyspark/mllib/tests.py"
+ run_test "pyspark.mllib.classification"
+ run_test "pyspark.mllib.clustering"
+ run_test "pyspark.mllib.evaluation"
+ run_test "pyspark.mllib.feature"
+ run_test "pyspark.mllib.fpm"
+ run_test "pyspark.mllib.linalg"
+ run_test "pyspark.mllib.random"
+ run_test "pyspark.mllib.recommendation"
+ run_test "pyspark.mllib.regression"
+ run_test "pyspark.mllib.stat._statistics"
+ run_test "pyspark.mllib.tree"
+ run_test "pyspark.mllib.util"
+ run_test "pyspark.mllib.tests"
}
function run_ml_tests() {
echo "Run ml tests ..."
- run_test "pyspark/ml/feature.py"
- run_test "pyspark/ml/classification.py"
- run_test "pyspark/ml/recommendation.py"
- run_test "pyspark/ml/regression.py"
- run_test "pyspark/ml/tuning.py"
- run_test "pyspark/ml/tests.py"
- run_test "pyspark/ml/evaluation.py"
+ run_test "pyspark.ml.feature"
+ run_test "pyspark.ml.classification"
+ run_test "pyspark.ml.recommendation"
+ run_test "pyspark.ml.regression"
+ run_test "pyspark.ml.tuning"
+ run_test "pyspark.ml.tests"
+ run_test "pyspark.ml.evaluation"
}
function run_streaming_tests() {
@@ -124,8 +126,8 @@ function run_streaming_tests() {
done
export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
- run_test "pyspark/streaming/util.py"
- run_test "pyspark/streaming/tests.py"
+ run_test "pyspark.streaming.util"
+ run_test "pyspark.streaming.tests"
}
echo "Running PySpark tests. Output is in python/$LOG_FILE."
diff --git a/repl/pom.xml b/repl/pom.xml
index 03053b4c3b287..6e5cb7f77e1df 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -48,6 +48,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-bagel_${scala.binary.version}</artifactId>
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 934daaeaafca1..50fd43a418bca 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -22,13 +22,12 @@ import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 14f5e9ed4f25e..9ecc7c229e38a 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.tools.nsc.interpreter.SparkILoop
-import org.scalatest.FunSuite
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index c709cde740748..a58eda12b1120 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -25,7 +25,6 @@ import scala.language.implicitConversions
import scala.language.postfixOps
import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
import org.scalatest.concurrent.Interruptor
import org.scalatest.concurrent.Timeouts._
import org.scalatest.mock.MockitoSugar
@@ -35,7 +34,7 @@ import org.apache.spark._
import org.apache.spark.util.Utils
class ExecutorClassLoaderSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfterAll
with MockitoSugar
with Logging {
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 68d980b610c00..d6f927b6fa803 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -14,25 +14,41 @@
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
 [The body of this scalastyle-config.xml hunk was garbled during extraction:
 the XML tags were stripped, leaving only stray rule parameters, so the full
 rule definitions cannot be reconstructed here. What survives indicates the
 following changes: the "Scalastyle standard configuration" header block is
 rewritten; the token list of one whitespace check grows from "ARROW, EQUALS"
 to "ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW", and another
 from "ARROW, EQUALS, COMMA, COLON, IF, WHILE, FOR" to "ARROW, EQUALS, COMMA,
 COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW";
 a new regex check matching ^FunSuite[A-Za-z]*$ is added with the custom
 message "Tests must extend org.apache.spark.SparkFunSuite instead."; a second
 new regex check flags ^println$; and a group of checks carries the parameters
 800, 30, 10, 50, and -1,0,1,2,3 (plausibly maximum file length, type count,
 cyclomatic complexity, method count, and a magic-number whitelist, though the
 stripped markup makes this uncertain).]
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5c322d032d474..d9e1cdb84bb27 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -50,6 +50,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 75a493b248f6e..1c0ddb5093d17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -233,7 +233,7 @@ object CatalystTypeConverters {
case other => other
}
- /**
+ /**
* Converts Catalyst types used internally in rows to standard Scala types
* This method is slow, and for batch conversion you should be using converter
* produced by createToScalaConverter.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index df37889eedcf0..bc17169f35a46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -633,10 +633,10 @@ class Analyzer(
* it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
- def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
+ private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
projectList.exists(hasWindowFunction)
- def hasWindowFunction(expr: NamedExpression): Boolean = {
+ private def hasWindowFunction(expr: NamedExpression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
@@ -644,14 +644,24 @@ class Analyzer(
}
/**
- * From a Seq of [[NamedExpression]]s, extract window expressions and
- * other regular expressions.
+ * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
+ * other regular expressions that do not contain any window expression. For example, for
+ * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
+ * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in
+ * the window expression as attribute references. So, the first returned value will be
+ * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
+ * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
+ *
+ * @return (seq of expressions containing at least one window expression,
+ * seq of non-window expressions)
*/
- def extract(
+ private def extract(
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
- // First, we simple partition the input expressions to two part, one having
- // WindowExpressions and another one without WindowExpressions.
- val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
+ // First, we partition the input expressions into two parts. Every expression
+ // in the first part contains at least one WindowExpression.
+ // Expressions in the second part do not have any WindowExpression.
+ val (expressionsWithWindowFunctions, regularExpressions) =
+ expressions.partition(hasWindowFunction)
// Then, we need to extract those regular expressions used in the WindowExpression.
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
@@ -660,8 +670,8 @@ class Analyzer(
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
- // If a named expression is not in regularExpressions, add extract it and replace it
- // with an AttributeReference.
+ // If a named expression is not in regularExpressions, add it to
+ // extractedExprBuffer and replace it with an AttributeReference.
val missingExpr =
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
if (missingExpr.nonEmpty) {
@@ -678,8 +688,9 @@ class Analyzer(
withName.toAttribute
}
- // Now, we extract expressions from windowExpressions by using extractExpr.
- val newWindowExpressions = windowExpressions.map {
+ // Now, we extract regular expressions from expressionsWithWindowFunctions
+ // by using extractExpr.
+ val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
@@ -705,37 +716,80 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}
- (newWindowExpressions, regularExpressions ++ extractedExprBuffer)
- }
+ (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
+ } // end of extract
/**
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
*/
- def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
- // First, we group window expressions based on their Window Spec.
- val groupedWindowExpression = windowExpressions.groupBy { expr =>
- val windowSpec = expr.collectFirst {
+ private def addWindow(
+ expressionsWithWindowFunctions: Seq[NamedExpression],
+ child: LogicalPlan): LogicalPlan = {
+ // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
+ // and put those extracted WindowExpressions to extractedWindowExprBuffer.
+ // This step is needed because it is possible that an expression contains multiple
+ // WindowExpressions with different Window Specs.
+ // After extracting WindowExpressions, we need to construct a project list to generate
+ // expressionsWithWindowFunctions based on extractedWindowExprBuffer.
+ // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
+ // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
+ // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
+ // Then, the projectList will be [_we0/_we1].
+ val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
+ val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
+ // We need to use transformDown because we want to trigger
+ // "case alias @ Alias(window: WindowExpression, _)" first.
+ _.transformDown {
+ case alias @ Alias(window: WindowExpression, _) =>
+ // If a WindowExpression has an assigned alias, just use it.
+ extractedWindowExprBuffer += alias
+ alias.toAttribute
+ case window: WindowExpression =>
+ // If there is no alias assigned to the WindowExpression, we create an
+ // internal column.
+ val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
+ extractedWindowExprBuffer += withName
+ withName.toAttribute
+ }.asInstanceOf[NamedExpression]
+ }
+
+ // Second, we group extractedWindowExprBuffer based on their Window Spec.
+ val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
+ val distinctWindowSpec = expr.collect {
case window: WindowExpression => window.windowSpec
+ }.distinct
+
+ // We do a final check and see if we only have a single Window Spec defined
+ // in an expression.
+ if (distinctWindowSpec.length == 0) {
+ failAnalysis(s"$expr does not have any WindowExpression.")
+ } else if (distinctWindowSpec.length > 1) {
+ // newExpressionsWithWindowFunctions only have expressions with a single
+ // WindowExpression. If we reach here, we have a bug.
+ failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec). " +
+ s"Please file a bug report with this error message, stack trace, and the query.")
+ } else {
+ distinctWindowSpec.head
}
- windowSpec.getOrElse(
- failAnalysis(s"$windowExpressions does not have any WindowExpression."))
}.toSeq
- // For every Window Spec, we add a Window operator and set currentChild as the child of it.
+ // Third, for every Window Spec, we add a Window operator and set currentChild as the
+ // child of it.
var currentChild = child
var i = 0
- while (i < groupedWindowExpression.size) {
- val (windowSpec, windowExpressions) = groupedWindowExpression(i)
+ while (i < groupedWindowExpressions.size) {
+ val (windowSpec, windowExpressions) = groupedWindowExpressions(i)
// Set currentChild to the newly created Window operator.
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
- // Move to next WindowExpression.
+ // Move to next Window Spec.
i += 1
}
- // We return the top operator.
- currentChild
- }
+ // Finally, we create a Project to output currentChild's output plus
+ // newExpressionsWithWindowFunctions.
+ Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
+ } // end of addWindow
// We have to use transformDown at here to make sure the rule of
// "Aggregate with Having clause" will be triggered.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 44664f898f762..edcc918bfe921 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -76,8 +76,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
- BooleanComparisons ::
- BooleanCasts ::
+ BooleanEqualization ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
@@ -120,7 +119,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
- val stringNaN = Literal("NaN")
+ private val stringNaN = Literal("NaN")
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
@@ -297,8 +296,8 @@ trait HiveTypeCoercion {
object InConversion extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
- case e if !e.childrenResolved => e
-
+ case e if !e.childrenResolved => e
+
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
i.makeCopy(Array(a, b.map(Cast(_, a.dataType))))
}
@@ -350,17 +349,17 @@ trait HiveTypeCoercion {
import scala.math.{max, min}
// Conversion rules for integer types into fixed-precision decimals
- val intTypeToFixed: Map[DataType, DecimalType] = Map(
+ private val intTypeToFixed: Map[DataType, DecimalType] = Map(
ByteType -> DecimalType(3, 0),
ShortType -> DecimalType(5, 0),
IntegerType -> DecimalType(10, 0),
LongType -> DecimalType(20, 0)
)
- def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+ private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
// Conversion rules for float and double into fixed-precision decimals
- val floatTypeToFixed: Map[DataType, DecimalType] = Map(
+ private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
@@ -483,56 +482,66 @@ trait HiveTypeCoercion {
}
/**
- * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
+ * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
- object BooleanComparisons extends Rule[LogicalPlan] {
- val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
- val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
+ object BooleanEqualization extends Rule[LogicalPlan] {
+ private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
+ private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
+
+ private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
+ CaseKeyWhen(numericExpr, Seq(
+ Literal(trueValues.head), booleanExpr,
+ Literal(falseValues.head), Not(booleanExpr),
+ Literal(false)))
+ }
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- // Skip nodes who's children have not been resolved yet.
- case e if !e.childrenResolved => e
+ private def transform(booleanExpr: Expression, numericExpr: Expression) = {
+ If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
+ Literal.create(null, BooleanType),
+ buildCaseKeyWhen(booleanExpr, numericExpr))
+ }
- // Hive treats (true = 1) as true and (false = 0) as true.
- case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
- case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
- case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
- case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
-
- // No need to change other EqualTo operators as that actually makes sense for boolean types.
- case e: EqualTo => e
- // No need to change the EqualNullSafe operators, too
- case e: EqualNullSafe => e
- // Otherwise turn them to Byte types so that there exists and ordering.
- case p: BinaryComparison if p.left.dataType == BooleanType &&
- p.right.dataType == BooleanType =>
- p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
+ private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
+ CaseWhen(Seq(
+ And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
+ Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
+ buildCaseKeyWhen(booleanExpr, numericExpr)
+ ))
}
- }
- /**
- * Casts to/from [[BooleanType]] are transformed into comparisons since
- * the JVM does not consider Booleans to be numeric types.
- */
- object BooleanCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- // Skip if the type is boolean type already. Note that this extra cast should be removed
- // by optimizer.SimplifyCasts.
- case Cast(e, BooleanType) if e.dataType == BooleanType => e
- // DateType should be null if be cast to boolean.
- case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
- // If the data type is not boolean and is being cast boolean, turn it into a comparison
- // with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
- case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
- // Stringify boolean if casting to StringType.
- // TODO Ensure true/false string letter casing is consistent with Hive in all cases.
- case Cast(e, StringType) if e.dataType == BooleanType =>
- If(e, Literal("true"), Literal("false"))
- // Turn true into 1, and false into 0 if casting boolean into other types.
- case Cast(e, dataType) if e.dataType == BooleanType =>
- Cast(If(e, Literal(1), Literal(0)), dataType)
+
+ // Hive treats (true = 1) as true and (false = 0) as true,
+ // all other cases are considered as false.
+
+ // We may simplify the expression if one side is literal numeric values
+ case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => l
+ case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => Not(l)
+ case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+ if trueValues.contains(value) => r
+ case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
+ if falseValues.contains(value) => Not(r)
+ case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+ if trueValues.contains(value) => And(IsNotNull(l), l)
+ case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
+ if falseValues.contains(value) => And(IsNotNull(l), Not(l))
+ case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+ if trueValues.contains(value) => And(IsNotNull(r), r)
+ case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
+ if falseValues.contains(value) => And(IsNotNull(r), Not(r))
+
+ case EqualTo(l @ BooleanType(), r @ NumericType()) =>
+ transform(l , r)
+ case EqualTo(l @ NumericType(), r @ BooleanType()) =>
+ transform(r, l)
+ case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
+ transformNullSafe(l, r)
+ case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
+ transformNullSafe(r, l)
}
}
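The semantics the new rule encodes reduce to a small value table. Here it is as a plain-Python sketch for the EqualTo path, with None standing for SQL NULL (the null-safe variant instead maps both-NULL to true and one-NULL to false):

    def boolean_numeric_eq(b, n):
        # Hive semantics: true = 1 and false = 0 hold; any other numeric
        # value compares false; NULL on either side propagates.
        if b is None or n is None:
            return None
        if n == 1:
            return b
        if n == 0:
            return not b
        return False

    assert boolean_numeric_eq(True, 1) is True
    assert boolean_numeric_eq(False, 0) is True
    assert boolean_numeric_eq(True, 2) is False
    assert boolean_numeric_eq(None, 1) is None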
@@ -633,7 +642,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
+ case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -652,6 +661,23 @@ trait HiveTypeCoercion {
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
+
+ case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
+ val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
+ findTightestCommonType(v1, v2).getOrElse(sys.error(
+ s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+ }
+ val transformedBranches = ckw.branches.sliding(2, 2).map {
+ case Seq(when, then) if when.dataType != commonType =>
+ Seq(Cast(when, commonType), then)
+ case s => s
+ }.reduce(_ ++ _)
+ val transformedKey = if (ckw.key.dataType != commonType) {
+ Cast(ckw.key, commonType)
+ } else {
+ ckw.key
+ }
+ CaseKeyWhen(transformedKey, transformedBranches)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index d19928784442e..adc6505d69cdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -37,7 +37,15 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
*/
def foldable: Boolean = false
+
+ /**
+ * Returns true when the current expression always returns the same result for fixed input values.
+ */
+ // TODO: Need to define explicit input values vs implicit input values.
+ def deterministic: Boolean = true
+
def nullable: Boolean
+
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
/** Returns the result of evaluating this expression on a given input Row */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 195eec8e5cdc4..99340a14c9ecc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -29,7 +29,7 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
-case class SortOrder(child: Expression, direction: SortDirection) extends Expression
+case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 6c380d3084652..0266084a6d174 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -394,13 +394,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
* Combining PartitionLevel InputData
* <-- null
* Zero <-- Zero <-- null
- *
+ *
* <-- null <-- no data
- * null <-- null <-- no data
+ * null <-- null <-- no data
*/
case class CombineSum(child: Expression) extends AggregateExpression {
def this() = this(null)
-
+
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
@@ -616,7 +616,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
private val sum = MutableLiteral(null, calcType)
- private val addFunction =
+ private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: Row): Unit = {
@@ -634,7 +634,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
-
+
def this() = this(null, null) // Required for serialization.
private val calcType =
@@ -649,12 +649,12 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
private val sum = MutableLiteral(null, calcType)
- private val addFunction =
+ private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
+
override def update(input: Row): Unit = {
val result = expr.eval(input)
- // partial sum result can be null only when no input rows present
+ // partial sum result can be null only when no input rows present
if(result != null) {
sum.update(addFunction, input)
}
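
The addFunction expression reindented above is dense; its intent, per the comment, is that the running sum stays NULL until the first non-null input and thereafter ignores nulls. A plain-Scala sketch of those semantics (Option standing in for SQL NULL; not the Catalyst code):

{{{
// Option models SQL NULL: the sum is None until the first non-null
// value arrives, and null inputs leave the accumulated sum unchanged.
def step(sum: Option[Long], input: Option[Long]): Option[Long] = input match {
  case None    => sum
  case Some(v) => Some(sum.getOrElse(0L) + v)
}

val total = Seq(None, Some(2L), None, Some(3L))
  .foldLeft(Option.empty[Long])(step)
assert(total == Some(5L))  // SUM over no rows would stay None (NULL)
}}}
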
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 34c833b260dc0..f2299d5db6e9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -180,7 +180,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
case other => sys.error(s"Type $other does not support numeric operations")
}
-
+
override def eval(input: Row): Any = {
val evalE2 = right.eval(input)
if (evalE2 == null || evalE2 == 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index e7cd7131a9e56..6398b8f9e4ed7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
-
+
lazy val childTypes = children.map(_.dataType).distinct
override lazy val resolved =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
index 890efc9f52ca3..01f62ba0442e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
* @param f The math function.
* @param name The short name of the function
*/
-abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
+abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product =>
override def symbol: String = null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e2d1c8115e051..4f422d69c4382 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
// both then and else val should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
- def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
+ def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
override def dataType: DataType = {
if (!resolved) {
@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
override def children: Seq[Expression] = key +: branches
override lazy val resolved: Boolean =
- childrenResolved && valueTypesEqual
+ childrenResolved && valueTypesEqual &&
+ (key +: whenList).map(_.dataType).distinct.size == 1
/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
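
The tightened `resolved` condition above only marks a CaseKeyWhen resolved once the key and every WHEN operand agree on a single type, which is what hands mixed-type cases to the coercion rule. In miniature (type names as plain strings, purely illustrative):

{{{
// resolved now also requires one distinct type across key and WHENs.
val keyAndWhenTypes = Seq("ShortType", "IntegerType")
val singleType = keyAndWhenTypes.distinct.size == 1
assert(!singleType)  // not resolved yet: CaseWhenCoercion must insert Casts
}}}
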
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index de82c15680607..b2647124c4e49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -24,7 +24,7 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* A Random distribution generating expression.
- * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
* StructType.
*
* Since this expression is stateful, it cannot be a case object.
@@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
*/
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+ override def deterministic: Boolean = false
+
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
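
The seeding scheme above (`seed + TaskContext.get().partitionId()`) gives each partition its own reproducible stream. A sketch of the idea, with scala.util.Random standing in for XORShiftRandom:

{{{
import scala.util.Random

// Offsetting the user seed by the partition id yields independent,
// re-run-stable streams per partition.
def partitionRng(seed: Long, partitionId: Int): Random = new Random(seed + partitionId)

val first = partitionRng(42L, 0).nextDouble()
val rerun = partitionRng(42L, 0).nextDouble()
val other = partitionRng(42L, 1).nextDouble()
assert(first == rerun)  // same partition, same stream
assert(first != other)  // different partition, different stream (w.h.p.)
}}}
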
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 83a44a12f0682..c4ef9c30907f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -133,7 +133,7 @@ trait CaseConversionExpression extends ExpectsInputTypes {
* A function that converts the characters of a string to uppercase.
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
-
+
override def convert(v: UTF8String): UTF8String = v.toUpperCase()
override def toString: String = s"Upper($child)"
@@ -143,7 +143,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
* A function that converts the characters of a string to lowercase.
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
-
+
override def convert(v: UTF8String): UTF8String = v.toLowerCase()
override def toString: String = s"Lower($child)"
@@ -223,7 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
@inline
def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
// Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
- // negative indices for start positions. If a start index i is greater than 0, it
+ // negative indices for start positions. If a start index i is greater than 0, it
// refers to element i-1 in the sequence. If a start index i is less than 0, it refers
// to the -ith element before the end of the sequence. If a start index i is 0, it
// refers to the first element.
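
The SUBSTR indexing rules spelled out in the comment above are easy to get wrong; here is a standalone Scala rendering of them (Hive/SQL semantics, not the Catalyst implementation):

{{{
// start > 0 is one-based; start < 0 counts back from the end;
// start == 0 behaves like start == 1 (assumes len >= 0).
def substr(s: String, start: Int, len: Int): String = {
  val from = (if (start > 0) start - 1 else if (start < 0) s.length + start else 0).max(0)
  s.slice(from, (from + len).min(s.length))
}

assert(substr("Spark", 1, 2)  == "Sp")  // one-based start
assert(substr("Spark", -3, 2) == "ar")  // third char from the end
assert(substr("Spark", 0, 2)  == "Sp")  // zero treated as one
}}}
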
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c2818d957cc79..b25fb48f55e2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -179,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
* expressions into one single expression.
*/
object ProjectCollapsing extends Rule[LogicalPlan] {
+
+ /** Returns true if any expression in projectList is non-deterministic. */
+ private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = {
+ projectList.exists(expr => expr.find(!_.deterministic).isDefined)
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case Project(projectList1, Project(projectList2, child)) =>
+ // We only collapse these two Projects if the child Project's expressions are all
+ // deterministic.
+ case Project(projectList1, Project(projectList2, child))
+ if !hasNondeterministic(projectList2) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
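
The new guard matters because collapsing substitutes child aliases into the parent's expressions, which duplicates their evaluation. A plain-Scala sketch of why that is unsound for a nondeterministic alias like rand():

{{{
val rng = new scala.util.Random()
def rand(): Int = rng.nextInt(100)

// Uncollapsed: the alias is evaluated once and reused.
val r = rand()
val (a1, a2) = (r + 1, r + 2)           // a2 - a1 == 1, always
assert(a2 - a1 == 1)

// Naively collapsed: the alias body is inlined twice, drawing twice.
val (b1, b2) = (rand() + 1, rand() + 2) // b2 - b1 == 1 no longer guaranteed
}}}
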
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 54604808e133e..1ba3a2686639f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -165,6 +165,9 @@ object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
+ /**
+ * @deprecated As of 1.2.0, replaced by `DataType.fromJson()`
+ */
@deprecated("Use DataType.fromJson instead", "1.2.0")
def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
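
For reference, the non-deprecated path: `fromJson` round-trips the JSON representation that `DataType.json` produces. A minimal sketch (simple types serialize as bare JSON strings):

{{{
import org.apache.spark.sql.types._

// Simple types round-trip through their JSON names.
assert(IntegerType.json == "\"integer\"")
assert(DataType.fromJson("\"integer\"") == IntegerType)

// Struct types round-trip too.
val struct = StructType(StructField("a", LongType) :: Nil)
assert(DataType.fromJson(struct.json) == struct)
}}}
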
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 994c5202c15dc..eb3c58c37f308 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -313,7 +313,7 @@ object Decimal {
// See scala.math's Numeric.scala for examples for Scala's built-in types.
/** Common methods for Decimal evidence parameters */
- trait DecimalIsConflicted extends Numeric[Decimal] {
+ private[sql] trait DecimalIsConflicted extends Numeric[Decimal] {
override def plus(x: Decimal, y: Decimal): Decimal = x + y
override def times(x: Decimal, y: Decimal): Decimal = x * y
override def minus(x: Decimal, y: Decimal): Decimal = x - y
@@ -327,12 +327,12 @@ object Decimal {
}
/** A [[scala.math.Fractional]] evidence parameter for Decimals. */
- object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
+ private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
override def div(x: Decimal, y: Decimal): Decimal = x / y
}
/** A [[scala.math.Integral]] evidence parameter for Decimals. */
- object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
+ private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
override def quot(x: Decimal, y: Decimal): Decimal = x / y
override def rem(x: Decimal, y: Decimal): Decimal = x % y
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 0f8cecd28f7df..407dc27326c2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -82,12 +82,12 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
object DecimalType {
val Unlimited: DecimalType = DecimalType(None)
- object Fixed {
+ private[sql] object Fixed {
def unapply(t: DecimalType): Option[(Int, Int)] =
t.precisionInfo.map(p => (p.precision, p.scale))
}
- object Expression {
+ private[sql] object Expression {
def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
case _ => None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
index a64d2bb7cde37..df64a878b6b36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java
@@ -24,11 +24,11 @@
/**
* ::DeveloperApi::
* A user-defined type which can be automatically recognized by a SQLContext and registered.
- *
+ *
* WARNING: This annotation will only work if both Java and Scala reflection return the same class
* names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
* is enclosed in an object (a singleton).
- *
+ *
* WARNING: UDTs are currently only supported from Scala.
*/
// TODO: Should I use @Documented ?
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index a4f30c825befb..193c08a4d0df7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -265,7 +265,7 @@ object StructType {
case _ =>
throw new SparkException(s"Failed to merge incompatible data types $left and $right")
}
-
+
private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = {
import scala.collection.breakOut
fields.map(s => (s.name, s))(breakOut)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
index bc9c37bf2d5d2..f5d8fcced362b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -203,7 +203,7 @@ object UTF8String {
def apply(s: String): UTF8String = {
if (s != null) {
new UTF8String().set(s)
- } else{
+ } else {
null
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index ea82cd2622de9..c046dbf4dc2c9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.catalyst
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.physical._
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
-class DistributionSuite extends FunSuite {
+class DistributionSuite extends SparkFunSuite {
protected def checkSatisfied(
inputPartitioning: Partitioning,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 7ff51db76b6bb..9a24b23024e18 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst
import java.math.BigInteger
import java.sql.{Date, Timestamp}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.types._
@@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
def this(b: String, a: Int) = this(a, b, c = 1.0)
}
-class ScalaReflectionSuite extends FunSuite {
+class ScalaReflectionSuite extends SparkFunSuite {
import ScalaReflection._
test("primitive data") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
index 9eed15952d82b..b93a3abc6ebd2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Command
-import org.scalatest.FunSuite
private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
override def output: Seq[Attribute] = Seq.empty
@@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
}
}
-class SqlParserSuite extends FunSuite {
+class SqlParserSuite extends SparkFunSuite {
test("test long keyword") {
val parser = new SuperLongKeywordTestParser
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index fcff24ca31486..e09cd790a7187 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-class AnalysisSuite extends FunSuite with BeforeAndAfter {
+class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
val caseSensitiveConf = new SimpleCatalystConf(true)
val caseInsensitiveConf = new SimpleCatalystConf(false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 565b1cfe019c7..1b8d18ded2257 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -17,14 +17,15 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
-class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
+class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
val conf = new SimpleCatalystConf(true)
val catalog = new SimpleCatalog(conf)
val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index fcd745f43cfbf..a0798428db094 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
class HiveTypeCoercionSuite extends PlanTest {
@@ -104,31 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}
- test("boolean casts") {
- val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
- def ruleTest(initial: Expression, transformed: Expression) {
- val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
- comparePlans(
- booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)),
- Project(Seq(Alias(transformed, "a")()), testRelation))
- }
- // Remove superflous boolean -> boolean casts.
- ruleTest(Cast(Literal(true), BooleanType), Literal(true))
- // Stringify boolean when casting to string.
- ruleTest(
- Cast(Literal(false), StringType),
- If(Literal(false), Literal("true"), Literal("false")))
+ private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ comparePlans(
+ rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+ Project(Seq(Alias(transformed, "a")()), testRelation))
}
test("coalesce casts") {
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
- def ruleTest(initial: Expression, transformed: Expression) {
- val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
- comparePlans(
- fac(Project(Seq(Alias(initial, "a")()), testRelation)),
- Project(Seq(Alias(transformed, "a")()), testRelation))
- }
- ruleTest(
+ ruleTest(fac,
Coalesce(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -137,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(
+ ruleTest(fac,
Coalesce(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -147,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
:: Nil))
}
+
+ test("type coercion for CaseKeyWhen") {
+ val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
+ ruleTest(cwc,
+ CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
+ CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
+ )
+ // Will remove exception expectation in PR#6405
+ intercept[RuntimeException] {
+ ruleTest(cwc,
+ CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
+ CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
+ )
+ }
+ }
+
+ test("type coercion simplification for equal to") {
+ val be = new HiveTypeCoercion {}.BooleanEqualization
+ ruleTest(be,
+ EqualTo(Literal(true), Literal(1)),
+ Literal(true)
+ )
+ ruleTest(be,
+ EqualTo(Literal(true), Literal(0)),
+ Not(Literal(true))
+ )
+ ruleTest(be,
+ EqualNullSafe(Literal(true), Literal(1)),
+ And(IsNotNull(Literal(true)), Literal(true))
+ )
+ ruleTest(be,
+ EqualNullSafe(Literal(true), Literal(0)),
+ And(IsNotNull(Literal(true)), Not(Literal(true)))
+ )
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
index f2f3a84d19380..97cfb5f06dd73 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType
-class AttributeSetSuite extends FunSuite {
+class AttributeSetSuite extends SparkFunSuite {
val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index a14f776b1eaee..b6927485f42bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -22,9 +22,9 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet
import org.scalactic.TripleEqualsSupport.Spread
-import org.scalatest.FunSuite
import org.scalatest.Matchers._
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
-class ExpressionEvaluationBaseSuite extends FunSuite {
+class ExpressionEvaluationBaseSuite extends SparkFunSuite {
def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
expression.eval(inputRow)
@@ -372,6 +372,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
+ checkEvaluation(Literal(true) cast StringType, "true")
+ checkEvaluation(Literal(false) cast StringType, "false")
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
checkEvaluation("23" cast DoubleType, 23d)
@@ -860,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)
- val literalNull = Literal.create(null, BooleanType)
+ val literalNull = Literal.create(null, IntegerType)
val literalInt = Literal(1)
val literalString = Literal("a")
@@ -869,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
- checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
+ checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
- checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
- checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
+ checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
+ checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}
test("complex type") {
@@ -1207,7 +1209,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
}
/**
- * Used for testing math functions for DataFrames.
+ * Used for testing math functions for DataFrames.
* @param c The DataFrame function
* @param f The functions in scala.math
* @param domain The set of values to run the function with
@@ -1215,7 +1217,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
* @tparam T Generic type for primitives
*/
def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T](
- c: Expression => Expression,
+ c: Expression => Expression,
f: T => T,
domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
expectNull: Boolean = false): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 7a19e511eb8b5..88a36aa121b55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.JavaConverters._
import scala.util.Random
+import org.apache.spark.SparkFunSuite
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
-import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.apache.spark.sql.types._
-class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+class UnsafeFixedWidthAggregationMapSuite
+ extends SparkFunSuite
+ with Matchers
+ with BeforeAndAfterEach {
import UnsafeFixedWidthAggregationMap._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 3a60c7fd32675..61722f1ffa462 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Arrays
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
-class UnsafeRowConverterSuite extends FunSuite with Matchers {
+class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index a30052b38fc11..06c592f4905a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
-
+
test("limits: combines two limits after ColumnPruning") {
val originalQuery =
testRelation
@@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest {
.limit(2)
.select('a)
.limit(5)
-
+
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 5697c2272b8e8..ec3b2f1edfa05 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
-
+
test("Constant folding test: Fold In(v, list) into true or false") {
var originalQuery =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ff25470bf0946..17dc9124749e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -93,7 +93,7 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
-
+
test("column pruning for Project(ne, Limit)") {
val originalQuery =
testRelation
@@ -109,7 +109,7 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
-
+
// After this line is unimplemented.
test("simple push down") {
val originalQuery =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 11b0859d3f066..1d433275fed2e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -57,7 +57,7 @@ class OptimizeInSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
-
+
test("OptimizedIn test: In clause not optimized in case filter has attributes") {
val originalQuery =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
new file mode 100644
index 0000000000000..151654bffbd66
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Rand
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class ProjectCollapsingSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
+ Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int)
+
+ test("collapse two deterministic, independent projects into one") {
+ val query = testRelation
+ .select(('a + 1).as('a_plus_1), 'b)
+ .select('a_plus_1, ('b + 1).as('b_plus_1))
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("collapse two deterministic, dependent projects into one") {
+ val query = testRelation
+ .select(('a + 1).as('a_plus_1), 'b)
+ .select(('a_plus_1 + 1).as('a_plus_2), 'b)
+
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer = testRelation.select(
+ (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
+ 'b).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("do not collapse nondeterministic projects") {
+ val query = testRelation
+ .select(Rand(10).as('rand))
+ .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = query.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index e7cafcc96de87..765c1e2dda99f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
@@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._
/**
* Provides helper methods for comparing plans.
*/
-class PlanTest extends FunSuite {
+class PlanTest extends SparkFunSuite {
/**
* Since attribute references are given globally unique ids during analysis,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 1273921f6394c..62d5f6ac74885 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
@@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._
/**
* Tests for the sameResult function of [[LogicalPlan]].
*/
-class SameResultSuite extends FunSuite {
+class SameResultSuite extends SparkFunSuite {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index 2a641c63f87bb..a7de7b052bdc3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.catalyst.trees
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-class RuleExecutorSuite extends FunSuite {
+class RuleExecutorSuite extends SparkFunSuite {
object DecrementLiterals extends Rule[Expression] {
def apply(e: Expression): Expression = e transform {
case IntegerLiteral(i) if i > 0 => Literal(i - 1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 9fcfc51c96139..67db3d5e6d751 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.trees
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, StringType, NullType}
@@ -32,7 +31,7 @@ case class Dummy(optKey: Option[Expression]) extends Expression {
override def eval(input: Row): Any = null.asInstanceOf[Any]
}
-class TreeNodeSuite extends FunSuite {
+class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
assert(after === Literal(2))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
index d7d60efee50fa..4030a1b1df358 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.sql.catalyst.util
import org.json4s.jackson.JsonMethods.parse
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{MetadataBuilder, Metadata}
-class MetadataSuite extends FunSuite {
+class MetadataSuite extends SparkFunSuite {
val baseMetadata = new MetadataBuilder()
.putString("purpose", "ml")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
index 3e7cf7cbb5e63..c6171b7b6916d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.types
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class DataTypeParserSuite extends FunSuite {
+class DataTypeParserSuite extends SparkFunSuite {
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index df119827812f9..261c4fcad24aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.types
-import org.apache.spark.SparkException
-import org.scalatest.FunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
-class DataTypeSuite extends FunSuite {
+class DataTypeSuite extends SparkFunSuite {
test("construct an ArrayType") {
val array = ArrayType(StringType)
@@ -72,7 +71,7 @@ class DataTypeSuite extends FunSuite {
test("fieldsMap returns map of name to StructField") {
val struct = StructType(
- StructField("a", LongType) ::
+ StructField("a", LongType) ::
StructField("b", FloatType) :: Nil)
val mapped = StructType.fieldsMap(struct.fields)
@@ -91,7 +90,7 @@ class DataTypeSuite extends FunSuite {
val right = StructType(List())
val merged = left.merge(right)
-
+
assert(merged === left)
}
@@ -134,7 +133,7 @@ class DataTypeSuite extends FunSuite {
val right = StructType(
StructField("b", LongType) :: Nil)
-
+
intercept[SparkException] {
left.merge(right)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
index a22aa6f244c48..81d7ab010f394 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.types
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
// scalastyle:off
-class UTF8StringSuite extends FunSuite {
+class UTF8StringSuite extends SparkFunSuite {
test("basic") {
def check(str: String, len: Int) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index de6a2cd448c47..28b373e258311 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.types.decimal
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal
-import org.scalatest.{PrivateMethodTester, FunSuite}
+import org.scalatest.PrivateMethodTester
import scala.language.postfixOps
-class DecimalSuite extends FunSuite with PrivateMethodTester {
+class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
test("creating decimals") {
/** Check that a Decimal has the given string representation, precision and scale */
def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index ffe95bb49188f..8210c552603ea 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -41,6 +41,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b49b1d327289f..d3efa83380d04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -716,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def endsWith(literal: String): Column = this.endsWith(lit(literal))
+ /**
+ * Gives the column an alias. Same as `as`.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".alias("colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def alias(alias: String): Column = as(alias)
+
/**
* Gives the column an alias.
* {{{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index e90109446b642..034d887901975 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -57,14 +57,11 @@ private[sql] object DataFrame {
* :: Experimental ::
* A distributed collection of data organized into named columns.
*
- * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways
- * to create a [[DataFrame]]:
+ * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates
+ * a [[DataFrame]] by pointing Spark SQL to a Parquet data set.
* {{{
- * // Create a DataFrame from Parquet files
- * val people = sqlContext.parquetFile("...")
- *
- * // Create a DataFrame from data sources
- * val df = sqlContext.load("...", "json")
+ * val people = sqlContext.read.parquet("...") // in Scala
+ * DataFrame people = sqlContext.read().parquet("...") // in Java
* }}}
*
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
@@ -86,8 +83,8 @@ private[sql] object DataFrame {
* A more concrete example in Scala:
* {{{
* // To create DataFrame using SQLContext
- * val people = sqlContext.parquetFile("...")
- * val department = sqlContext.parquetFile("...")
+ * val people = sqlContext.read.parquet("...")
+ * val department = sqlContext.read.parquet("...")
*
* people.filter("age > 30")
* .join(department, people("deptId") === department("id"))
@@ -98,8 +95,8 @@ private[sql] object DataFrame {
* and in Java:
* {{{
* // To create DataFrame using SQLContext
- * DataFrame people = sqlContext.parquetFile("...");
- * DataFrame department = sqlContext.parquetFile("...");
+ * DataFrame people = sqlContext.read().parquet("...");
+ * DataFrame department = sqlContext.read().parquet("...");
*
* people.filter("age".gt(30))
* .join(department, people.col("deptId").equalTo(department("id")))
@@ -1444,7 +1441,9 @@ class DataFrame private[sql](
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
- /** Left here for backward compatibility. */
+ /**
+ * @deprecated As of 1.3.0, replaced by `toDF()`.
+ */
@deprecated("use toDF", "1.3.0")
def toSchemaRDD: DataFrame = this
@@ -1455,6 +1454,7 @@ class DataFrame private[sql](
* given name; if you pass `false`, it will throw if the table already
* exists.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().jdbc()`.
*/
@deprecated("Use write.jdbc()", "1.4.0")
def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
@@ -1473,6 +1473,7 @@ class DataFrame private[sql](
* the RDD in order via the simple statement
* `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().jdbc()`.
*/
@deprecated("Use write.jdbc()", "1.4.0")
def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
@@ -1485,6 +1486,7 @@ class DataFrame private[sql](
* Files that are written out using this method can be read back in as a [[DataFrame]]
* using the `parquetFile` function in [[SQLContext]].
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().parquet()`.
*/
@deprecated("Use write.parquet(path)", "1.4.0")
def saveAsParquetFile(path: String): Unit = {
@@ -1508,6 +1510,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`.
*/
@deprecated("Use write.saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String): Unit = {
@@ -1526,6 +1529,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, mode: SaveMode): Unit = {
@@ -1545,6 +1549,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, source: String): Unit = {
@@ -1564,6 +1569,7 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0")
def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = {
@@ -1582,6 +1588,8 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)",
"1.4.0")
@@ -1606,6 +1614,8 @@ class DataFrame private[sql](
* Also note that while this function can persist the table metadata into Hive's metastore,
* the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)",
"1.4.0")
@@ -1622,6 +1632,7 @@ class DataFrame private[sql](
* using the default data source configured by spark.sql.sources.default and
* [[SaveMode.ErrorIfExists]] as the save mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().save(path)`.
*/
@deprecated("Use write.save(path)", "1.4.0")
def save(path: String): Unit = {
@@ -1632,6 +1643,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode,
* using the default data source configured by spark.sql.sources.default.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`.
*/
@deprecated("Use write.mode(mode).save(path)", "1.4.0")
def save(path: String, mode: SaveMode): Unit = {
@@ -1642,6 +1654,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path based on the given data source,
* using [[SaveMode.ErrorIfExists]] as the save mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`.
*/
@deprecated("Use write.format(source).save(path)", "1.4.0")
def save(path: String, source: String): Unit = {
@@ -1652,6 +1665,7 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame to the given path based on the given data source and
* [[SaveMode]] specified by mode.
* @group output
+ * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0")
def save(path: String, source: String, mode: SaveMode): Unit = {
@@ -1662,6 +1676,8 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame based on the given data source,
* [[SaveMode]] specified by mode, and a set of options.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0")
def save(
@@ -1676,6 +1692,8 @@ class DataFrame private[sql](
* Saves the contents of this DataFrame based on the given data source,
* [[SaveMode]] specified by mode, and a set of options
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().format(source).mode(mode).options(options).save(path)`.
*/
@deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0")
def save(
@@ -1689,6 +1707,8 @@ class DataFrame private[sql](
/**
* Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0")
def insertInto(tableName: String, overwrite: Boolean): Unit = {
@@ -1699,6 +1719,8 @@ class DataFrame private[sql](
* Adds the rows from this RDD to the specified table.
* Throws an exception if the table already exists.
* @group output
+ * @deprecated As of 1.4.0, replaced by
+ * `write().mode(SaveMode.Append).saveAsTable(tableName)`.
*/
@deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0")
def insertInto(tableName: String): Unit = {
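
Taken together, the deprecations above all funnel into the new writer interface. A migration sketch, assuming an existing DataFrame `df` (paths and table names hypothetical):

{{{
import org.apache.spark.sql.SaveMode

// 1.3-style calls and their 1.4 writer equivalents.
df.saveAsParquetFile("/tmp/people")             // deprecated
df.write.parquet("/tmp/people")                 // replacement

df.save("/tmp/out", "json", SaveMode.Overwrite) // deprecated
df.write.format("json").mode(SaveMode.Overwrite).save("/tmp/out")

df.insertInto("t", overwrite = true)            // deprecated
df.write.mode(SaveMode.Overwrite).saveAsTable("t")
}}}
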
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 5d106c1ac2674..edb9ed7bba56a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
- * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+ * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param col1 the name of the column
@@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be greater than 1e-4.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @param support The minimum frequency for an item to be considered `frequent`. Should be greater
* than 1e-4.
@@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* Uses a `default` support of 1%.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
@@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
@@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* Uses a `default` support of 1%.
*
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting [[DataFrame]].
+ *
* @param cols the names of the columns to search frequent items in.
* @return A Local DataFrame with the Array of frequent items for each column.
*
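A minimal sketch of the freqItems API documented above (the data and column names
are hypothetical; sqlContext.implicits._ is assumed to be in scope for toDF):

    val df = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "a"), (1, "a")))
      .toDF("num", "letter")

    // Items appearing in at least 40% of rows, per column. The result is a
    // single-row local DataFrame whose schema carries no compatibility
    // guarantee, as the scaladoc above warns.
    val freq = df.stat.freqItems(Seq("num", "letter"), support = 0.4)
    freq.show()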
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 516ba2ac23371..45b3e1bc627d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -40,22 +40,22 @@ private[sql] object GroupedData {
/**
* The Grouping Type
*/
- trait GroupType
+ private[sql] trait GroupType
/**
* To indicate it's the GroupBy
*/
- object GroupByType extends GroupType
+ private[sql] object GroupByType extends GroupType
/**
* To indicate it's the CUBE
*/
- object CubeType extends GroupType
+ private[sql] object CubeType extends GroupType
/**
* To indicate it's the ROLLUP
*/
- object RollupType extends GroupType
+ private[sql] object RollupType extends GroupType
}
/**
@@ -249,7 +249,7 @@ class GroupedData protected[sql](
def mean(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Average)
}
-
+
/**
* Compute the max value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
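For context, the three GroupType markers above correspond to public entry points on
DataFrame; a hedged sketch (df and its columns are illustrative):

    df.groupBy("dept").count()         // GroupByType
    df.cube("dept", "year").count()    // CubeType: all combinations of grouping columns
    df.rollup("dept", "year").count()  // RollupType: (), (dept), and (dept, year)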
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a32897c20b474..91e6385dec81b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -182,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
conf.dialect
}
- sparkContext.getConf.getAll.foreach {
- case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
- case _ =>
+ {
+ // We extract Spark SQL settings from the SparkContext's conf and put them into
+ // Spark SQL's conf.
+ // First, we populate the SQLConf (conf), so that other values that rely on these
+ // settings during their construction pick up the correct values.
+ // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
+ // and spark.sql.hive.metastore.jars to be constructed correctly.
+ val properties = new Properties
+ sparkContext.getConf.getAll.foreach {
+ case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
+ case _ =>
+ }
+ // We put those settings directly into conf to avoid calling setConf, which may have
+ // side effects. For example, in HiveContext, setConf may cause executionHive and
+ // metadataHive to be constructed. If we called setConf here, the constructed
+ // metadataHive might have wrong settings, or the construction might fail.
+ conf.setConf(properties)
+ // After we have populated SQLConf, we call setConf to populate other confs in the subclass
+ // (e.g. hiveconf in HiveContext).
+ properties.foreach {
+ case (key, value) => setConf(key, value)
+ }
}
@transient
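A hedged end-to-end sketch of the propagation described in the comments above (the
key and value are illustrative):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sparkConf = new SparkConf()
      .setMaster("local")
      .setAppName("conf-propagation")
      .set("spark.sql.shuffle.partitions", "10")  // a spark.sql.* setting
    val sc = new SparkContext(sparkConf)
    val sqlContext = new SQLContext(sc)

    // The block above copied spark.sql.* entries into SQLConf first, then
    // called setConf for subclass-specific side effects.
    assert(sqlContext.getConf("spark.sql.shuffle.partitions") == "10")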
@@ -1021,21 +1040,33 @@ class SQLContext(@transient val sparkContext: SparkContext)
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD, schema)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD, schema)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
createDataFrame(rdd, beanClass)
}
+ /**
+ * @deprecated As of 1.3.0, replaced by `createDataFrame()`.
+ */
@deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
createDataFrame(rdd, beanClass)
@@ -1046,6 +1077,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* [[DataFrame]] if no paths are passed in.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().parquet()`.
*/
@deprecated("Use read.parquet()", "1.4.0")
@scala.annotation.varargs
@@ -1065,6 +1097,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String): DataFrame = {
@@ -1076,6 +1109,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String, schema: StructType): DataFrame = {
@@ -1084,6 +1118,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonFile(path: String, samplingRatio: Double): DataFrame = {
@@ -1096,6 +1131,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String]): DataFrame = read.json(json)
@@ -1106,6 +1142,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It goes through the entire dataset once to determine the schema.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json)
@@ -1115,6 +1152,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
@@ -1126,6 +1164,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = {
@@ -1137,6 +1176,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
@@ -1148,6 +1188,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* schema, returning the result as a [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().json()`.
*/
@deprecated("Use read.json()", "1.4.0")
def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = {
@@ -1159,6 +1200,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* using the default data source configured by spark.sql.sources.default.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().load(path)`.
*/
@deprecated("Use read.load(path)", "1.4.0")
def load(path: String): DataFrame = {
@@ -1169,6 +1211,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns the dataset stored at path as a DataFrame, using the given data source.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`.
*/
@deprecated("Use read.format(source).load(path)", "1.4.0")
def load(path: String, source: String): DataFrame = {
@@ -1180,6 +1223,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`.
*/
@deprecated("Use read.format(source).options(options).load()", "1.4.0")
def load(source: String, options: java.util.Map[String, String]): DataFrame = {
@@ -1191,6 +1235,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`.
*/
@deprecated("Use read.format(source).options(options).load()", "1.4.0")
def load(source: String, options: Map[String, String]): DataFrame = {
@@ -1202,6 +1247,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by
+ * `read().format(source).schema(schema).options(options).load()`.
*/
@deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame =
@@ -1214,6 +1261,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
* a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
*
* @group genericdata
+ * @deprecated As of 1.4.0, replaced by
+ * `read().format(source).schema(schema).options(options).load()`.
*/
@deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0")
def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = {
@@ -1225,6 +1274,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* url named table.
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(url: String, table: String): DataFrame = {
@@ -1242,6 +1292,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @param numPartitions the number of partitions. The range `minValue`-`maxValue` will be split
* evenly into this many partitions.
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(
@@ -1261,6 +1312,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* of the [[DataFrame]].
*
* @group specificdata
+ * @deprecated As of 1.4.0, replaced by `read().jdbc()`.
*/
@deprecated("use read.jdbc()", "1.4.0")
def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = {
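As with the writer API, the replacements named in these deprecations all funnel
through sqlContext.read; a brief sketch (paths, URLs, and table names are
illustrative):

    val viaParquet = sqlContext.read.parquet("/data/events")                 // was parquetFile(...)
    val viaJson    = sqlContext.read.json("/data/events.json")               // was jsonFile(...)
    val viaLoad    = sqlContext.read.format("parquet").load("/data/events")  // was load(path, source)
    val viaJdbc    = sqlContext.read.jdbc("jdbc:h2:mem:testdb0", "PEOPLE",
      new java.util.Properties)                                              // was jdbc(url, table)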
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 423ecdff5804a..604f3124e23ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -106,7 +106,7 @@ private[r] object SQLUtils {
dfCols.map { col =>
colToRBytes(col)
- }
+ }
}
def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
@@ -121,7 +121,7 @@ private[r] object SQLUtils {
val numRows = col.length
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
-
+
SerDe.writeInt(dos, numRows)
col.map { item =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 2ec7d4fbc92de..3e27c1bde2dfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -138,15 +138,15 @@ case class GeneratedAggregate(
case UnscaledValue(e) => e
case _ => expr
}
- // partial sum result can be null only when no input rows present
+ // partial sum result can be null only when no input rows present
val updateFunction = If(
IsNotNull(actualExpr),
Coalesce(
Add(
- Coalesce(currentSum :: zero :: Nil),
+ Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)) :: currentSum :: zero :: Nil),
currentSum)
-
+
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@@ -155,7 +155,7 @@ case class GeneratedAggregate(
}
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
+
case m @ Max(expr) =>
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)
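The updateFunction above encodes SQL SUM's null semantics; restated as a plain
Scala sketch (Option stands in for SQL NULL, Long for the numeric type):

    // newSum = if (input is not null) coalesce(coalesce(currentSum, 0) + input, currentSum, 0)
    //          else currentSum
    def updateSum(currentSum: Option[Long], input: Option[Long]): Option[Long] =
      input match {
        case Some(v) => Some(currentSum.getOrElse(0L) + v)  // first row starts from 0
        case None    => currentSum                          // NULL inputs leave the sum unchanged
      }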
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6cb67b4bbbb65..a30ade86441ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -65,7 +65,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
* :: DeveloperApi ::
* Sample the dataset.
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
- * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index fe8a81e3d0434..c41c21c0eeb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -62,7 +62,7 @@ private[sql] object FrequentItems extends Logging {
}
/**
- * Finding frequent items for columns, possibly with false positives. Using the
+ * Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be greater than 1e-4.
@@ -75,7 +75,7 @@ private[sql] object FrequentItems extends Logging {
* @return A Local DataFrame with the Array of frequent items for each column.
*/
private[sql] def singlePassFreqItems(
- df: DataFrame,
+ df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
require(support >= 1e-4, s"support ($support) must be greater than 1e-4.")
@@ -88,7 +88,7 @@ private[sql] object FrequentItems extends Logging {
val index = originalSchema.fieldIndex(name)
(name, originalSchema.fields(index).dataType)
}
-
+
val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index d22f5fd2d439c..93383e5a62f11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -18,14 +18,14 @@
package org.apache.spark.sql.execution.stat
import org.apache.spark.Logging
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
private[sql] object StatFunctions extends Logging {
-
+
/** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols)
@@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging {
s"exceed 1e4. Currently $columnSize")
val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
val countsRow = new GenericMutableRow(columnSize + 1)
- rows.foreach { row =>
+ rows.foreach { (row: Row) =>
+ // row.get(0) is column 1
+ // row.get(1) is column 2
+ // row.get(2) is the frequency
countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
}
// the value of col1 is the first value, the rest are the counts
@@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging {
val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq
val schema = StructType(StructField(tableName, StringType) +: headerNames)
- new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+ new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
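The trailing .na.fill(0.0) is what the DataFrameStatSuite changes below exercise:
crosstab cells for pairs that never co-occur now come back as 0 instead of null. A
hedged sketch (the data is illustrative):

    val df = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "a"))).toDF("k", "v")
    df.stat.crosstab("k", "v").show()
    // k_v  a  b
    //   1  1  1
    //   2  1  0   <- previously null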
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index d4003b2d9cbf6..e9b60841fc28c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -79,3 +79,20 @@ object Window {
}
}
+
+/**
+ * :: Experimental ::
+ * Utility functions for defining windows in DataFrames.
+ *
+ * {{{
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+ * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
+ *
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
+ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
+ * }}}
+ *
+ * @since 1.4.0
+ */
+@Experimental
+class Window private() // So we can see Window in JavaDoc.
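Putting the companion's examples to work end to end (the sales DataFrame and its
columns are illustrative; avg comes from org.apache.spark.sql.functions):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.avg

    // Centered 7-row moving average per country, mirroring the second example above
    val spec = Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
    val withMA = sales.withColumn("movingAvg", avg("amount").over(spec))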
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6dc17bbb2e768..77327f2b84eaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,7 +1299,7 @@ object functions {
* @since 1.4.0
*/
def toRadians(columnName: String): Column = toRadians(Column(columnName))
-
+
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 0bdb68e8ac845..40b604d710dce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
}
private def escapeSql(value: String): String =
- if (value == null) null else StringUtils.replace(value, "'", "''")
+ if (value == null) null else StringUtils.replace(value, "'", "''")
/**
* Turns a single Filter into a String representing a SQL expression.
@@ -304,7 +304,7 @@ private[sql] class JDBCRDD(
// Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
// we don't have to potentially poke around in the Metadata once for every
- // row.
+ // row.
// Is there a better way to do this? I'd rather be using a type that
// contains only the tags I define.
abstract class JDBCConversion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 09d6865457df6..30f9190d45bf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -54,7 +54,7 @@ private[sql] object JDBCRelation {
if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
// Overflow and silliness can happen if you subtract then divide.
// Here we get a little roundoff, but that's (hopefully) OK.
- val stride: Long = (partitioning.upperBound / numPartitions
+ val stride: Long = (partitioning.upperBound / numPartitions
- partitioning.lowerBound / numPartitions)
var i: Int = 0
var currentValue: Long = partitioning.lowerBound
@@ -140,10 +140,10 @@ private[sql] case class JDBCRelation(
filters,
parts)
}
-
+
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(url, table, properties)
- }
+ }
}
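The divide-before-subtract in the stride computation is what avoids Long overflow;
an illustration of the extreme case:

    val lowerBound = Long.MinValue
    val upperBound = Long.MaxValue
    val numPartitions = 4

    // Naive: (upperBound - lowerBound) / numPartitions wraps to -1 / 4 == 0.
    // Dividing each bound first keeps every intermediate value in range:
    val stride = upperBound / numPartitions - lowerBound / numPartitions
    // stride == 4611686018427387903, with a little roundoff, as the comment says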
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index f21dd29aca37f..dd8aaf6474895 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -240,10 +240,10 @@ package object jdbc {
}
}
}
-
+
def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
- case driver => driver.getClass.getCanonicalName
+ case driver => driver.getClass.getCanonicalName
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 3f97a11ceb97d..4e94fd07a8771 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -44,6 +44,7 @@ package object sql {
/**
* Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala.
+ * @deprecated As of 1.3.0, replaced by `DataFrame`.
*/
@deprecated("1.3.0", "use DataFrame")
type SchemaRDD = DataFrame
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 1b4196ab0be35..caa9f045537d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -243,8 +243,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
/**
* Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
* a long (i.e. precision <= 18).
+ *
+ * The returned value is needed by CatalystConverter, which doesn't reuse the Decimal object.
*/
- protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = {
+ protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = {
val precision = ctype.precisionInfo.get.precision
val scale = ctype.precisionInfo.get.scale
val bytes = value.getBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 8b3e1b2b59bf6..824ae36968c32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -155,7 +155,7 @@ private[sql] class ParquetRelation2(
meta
}
- override def equals(other: scala.Any): Boolean = other match {
+ override def equals(other: Any): Boolean = other match {
case that: ParquetRelation2 =>
val schemaEquality = if (shouldMergeSchemas) {
this.shouldMergeSchemas == that.shouldMergeSchemas
@@ -190,7 +190,7 @@ private[sql] class ParquetRelation2(
}
}
- override def dataSchema: StructType = metadataCache.dataSchema
+ override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema)
override private[sql] def refresh(): Unit = {
super.refresh()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
index dafdf0f8b4564..c4c99de5a38dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala
@@ -187,7 +187,7 @@ private[sql] object PartitioningUtils {
Seq.empty
} else {
assert(distinctPartitionsColNames.size == 1, {
- val list = distinctPartitionsColNames.mkString("\t", "\n", "")
+ val list = distinctPartitionsColNames.mkString("\t", "\n\t", "")
s"Conflicting partition column names detected:\n$list"
})
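The one-character separator change indents every conflicting column name, not just
the first:

    val names = Seq("colA", "colB")
    names.mkString("\t", "\n", "")    // "\tcolA\ncolB"   -- second name flush left
    names.mkString("\t", "\n\t", "")  // "\tcolA\n\tcolB" -- every name indented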
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
index a74a98631da35..ebad0c1564ec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala
@@ -216,7 +216,7 @@ private[sql] class SqlNewHadoopRDD[K, V](
override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = {
val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value
val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
+ case Some(c) =>
try {
val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
Some(HadoopRDD.convertSplitLocationInfo(infos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 3132067d562f6..71f016b1f14de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -30,9 +30,10 @@ import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode}
@@ -94,10 +95,19 @@ private[sql] case class InsertIntoHadoopFsRelation(
// We create a DataFrame by applying the schema of the relation to the data to make sure
// we are writing data based on the expected schema.
- val df = sqlContext.createDataFrame(
- DataFrame(sqlContext, query).queryExecution.toRdd,
- relation.schema,
- needsConversion = false)
+ val df = {
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after the data columns). We
+ // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation we can
+ // safely apply r.schema to the data.
+ val project = Project(
+ relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query)
+
+ sqlContext.createDataFrame(
+ DataFrame(sqlContext, project).queryExecution.toRdd,
+ relation.schema,
+ needsConversion = false)
+ }
val partitionColumns = relation.partitionColumns.fieldNames
if (partitionColumns.isEmpty) {
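Why the ordering differs in the first place: writing with partition columns
rearranges the schema. A hedged sketch (df, its columns, and the path are
illustrative):

    // The logical plan follows the select order: a, p, b.
    val data = df.select("a", "p", "b")

    // For a relation partitioned by "p", relation.schema puts the partition
    // column last: a, b, p. The Project built above re-expresses the query in
    // that order so relation.schema can be applied to the rows safely.
    data.write.partitionBy("p").format("parquet").save("/tmp/partitioned")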
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 22587f5a1c6f1..20afd60cb7767 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.RunnableCommand
@@ -322,19 +322,13 @@ private[sql] object ResolvedDataSource {
Some(partitionColumnsSchema(data.schema, partitionColumns)),
caseInsensitiveOptions)
- // For partitioned relation r, r.schema's column ordering is different with the column
- // ordering of data.logicalPlan. We need a Project to adjust the ordering.
- // So, inside InsertIntoHadoopFsRelation, we can safely apply the schema of r.schema to
- // the data.
- val project =
- Project(
- r.schema.map(field => new UnresolvedAttribute(Seq(field.name))),
- data.logicalPlan)
-
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after the data columns). This
+ // will be adjusted within InsertIntoHadoopFsRelation.
sqlContext.executePlan(
InsertIntoHadoopFsRelation(
r,
- project,
+ data.logicalPlan,
mode)).toRdd
r
case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index c06026e042d9f..f5bd2d2941ca0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -93,7 +93,7 @@ trait SchemaRelationProvider {
}
/**
- * ::DeveloperApi::
+ * ::Experimental::
* Implemented by objects that produce relations for a specific kind of data source
* with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a
* USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined
@@ -115,6 +115,7 @@ trait SchemaRelationProvider {
*
* @since 1.4.0
*/
+@Experimental
trait HadoopFsRelationProvider {
/**
* Returns a new base relation with the given parameters, a user defined schema, and a list of
@@ -378,10 +379,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]]
def refresh(): Unit = {
- // We don't filter files/directories whose name start with "_" or "." here, as specific data
- // sources may take advantages over them (e.g. Parquet _metadata and _common_metadata files).
- // But "_temporary" directories are explicitly ignored since failed tasks/jobs may leave
- // partial/corrupted data files there.
+ // We don't filter out files/directories whose names start with "_", except "_temporary",
+ // as specific data sources may take advantage of them (e.g. Parquet _metadata and
+ // _common_metadata files). "_temporary" directories are explicitly ignored since failed
+ // tasks/jobs may leave partial/corrupted data files there.
def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = {
if (status.getPath.getName.toLowerCase == "_temporary") {
Set.empty
@@ -399,6 +400,9 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
val fs = hdfsPath.getFileSystem(hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _))
+ }.filterNot { status =>
+ // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories
+ status.getPath.getName.startsWith(".")
}
val files = statuses.filterNot(_.isDir)
@@ -499,7 +503,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
*/
override lazy val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
- StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column =>
+ StructType(dataSchema ++ partitionColumns.filterNot { column =>
dataSchemaColumnNames.contains(column.name.toLowerCase)
})
}
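The SPARK-8037 filter drops dot-prefixed entries after listing, while
underscore-prefixed files other than "_temporary" survive for the benefit of
sources like Parquet; in miniature:

    import org.apache.hadoop.fs.Path

    val listed = Seq("/d/part-00000", "/d/_metadata", "/d/.DS_Store").map(new Path(_))
    val kept = listed.filterNot(_.getName.startsWith("."))
    // kept: part-00000 and _metadata; the hidden .DS_Store is ignored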
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d006b83fc075a..bfba379d9a518 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.scalatest.Matchers._
+import org.apache.spark.sql.execution.Project
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -27,6 +28,12 @@ import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ test("alias") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ assert(df.select(df("a").as("b")).columns.head === "b")
+ assert(df.select(df("a").alias("b")).columns.head === "b")
+ }
+
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
@@ -446,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest {
}
test("rand") {
- val randCol = testData.select('key, rand(5L).as("rand"))
+ val randCol = testData.select($"key", rand(5L).as("rand"))
randCol.columns.length should be (2)
val rows = randCol.collect()
rows.foreach { row =>
assert(row.getDouble(1) <= 1.0)
assert(row.getDouble(1) >= 0.0)
}
+
+ def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
+ val projects = df.queryExecution.executedPlan.collect {
+ case project: Project => project
+ }
+ assert(projects.size === expectedNumProjects)
+ }
+
+ // We first create a plan with two Projects.
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // Because the Rand function is not deterministic, the column rand is not deterministic.
+ // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // into Project [key, (Rand 5 + 1) AS rand]. The final plan still has two Projects.
+ val dfWithTwoProjects =
+ testData
+ .select($"key", (rand(5L) + 1).as("rand"))
+ .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2"))
+ checkNumProjects(dfWithTwoProjects, 2)
+
+ // Now, we add one more Project, rand1 - rand2, on top of the query plan.
+ // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated
+ // rand value), we can collapse rand1 - rand2 into the Project generating rand1 and rand2.
+ // So, the plan will be optimized from ...
+ // Project [(rand1 - rand2) AS (rand1 - rand2)]
+ // Project [rand + 1 AS rand1, rand - 1 AS rand2]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ // to ...
+ // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)]
+ // Project [key, (Rand 5 + 1) AS rand]
+ // LogicalRDD [key, value]
+ val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2")
+ checkNumProjects(dfWithThreeProjects, 2)
+ dfWithThreeProjects.collect().foreach { row =>
+ assert(row.getDouble(0) === 2.0 +- 0.0001)
+ }
}
test("randn") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 46b1845a9180c..438f479459dfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
import org.scalatest.Matchers._
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
-class DataFrameStatSuite extends FunSuite {
-
+class DataFrameStatSuite extends SparkFunSuite {
+
val sqlCtx = TestSQLContext
def toLetter(i: Int): String = (i + 97).toChar.toString
@@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite {
val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
assert(rows(0).get(0).toString === "0")
assert(rows(0).getLong(1) === 2L)
- assert(rows(0).get(2) === null)
+ assert(rows(0).get(2) === 0L)
assert(rows(1).get(0).toString === "1")
assert(rows(1).getLong(1) === 1L)
- assert(rows(1).get(2) === null)
+ assert(rows(1).get(2) === 0L)
assert(rows(2).get(0).toString === "2")
assert(rows(2).getLong(1) === 2L)
assert(rows(2).getLong(2) === 1L)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c4281c4b55c02..dd68965444f5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -206,7 +206,7 @@ class MathExpressionsSuite extends QueryTest {
}
test("log") {
- testOneToOneNonNegativeMathFunction(log, math.log)
+ testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
}
test("log10") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index fb3ba4bc1b908..513ac915dcb2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-class RowSuite extends FunSuite {
+class RowSuite extends SparkFunSuite {
test("create row") {
val expected = new GenericMutableRow(4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index bf73d0c7074a5..3a5f071e2f7cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql
-import org.scalatest.FunSuiteLike
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
-class SQLConfSuite extends QueryTest with FunSuiteLike {
+class SQLConfSuite extends QueryTest {
val testKey = "test.key.0"
val testVal = "test.val.0"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index f186bc1c18123..797d123b48668 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.TestSQLContext
-class SQLContextSuite extends FunSuite with BeforeAndAfterAll {
+class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
private val testSqlContext = TestSQLContext
private val testSparkContext = TestSQLContext.sparkContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bf18bf854aa4a..63f7d314fb699 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.types._
@@ -32,12 +32,12 @@ import org.apache.spark.sql.types._
/** A SQL dialect for testing purposes; it cannot be a nested type. */
class MyDialect extends DefaultParserDialect
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
+class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
- val sqlCtx = TestSQLContext
+ val sqlContext = TestSQLContext
+ import sqlContext.implicits._
test("SPARK-6743: no columns from cache") {
Seq(
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
+ val df2 = createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
+ val df3 = createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
+ val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
@@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}
+
+ test("SPARK-7952: fix the equality check between boolean and numeric types") {
+ withTempTable("t") {
+ // numeric field i, boolean field b, result of i = b, result of i <=> b
+ Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
+ (1, true, true, true),
+ (0, false, true, true),
+ (2, true, false, false),
+ (2, false, false, false),
+ (null, true, null, false),
+ (null, false, null, false),
+ (0, null, null, false),
+ (1, null, null, false),
+ (null, null, null, true)
+ ).toDF("i", "b", "r1", "r2").registerTempTable("t")
+
+ checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
+ checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
+ }
+ }
}
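The table encodes the distinction between = and the null-safe <=>: for non-null
operands the boolean is cast and compared numerically, while a NULL on either side
makes = yield NULL but gives <=> a definite answer (NULL <=> NULL is true,
NULL <=> x is false). Restated briefly with selectExpr (a hedged sketch):

    val t = Seq((Integer.valueOf(1), java.lang.Boolean.TRUE)).toDF("i", "b")
    t.selectExpr("i = b", "i <=> b").show()  // both columns are true for (1, true)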
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 52d265b445e14..d2ede39f0a5f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext._
@@ -74,7 +73,7 @@ case class ComplexReflectData(
mapFieldContainsNull: Map[Int, Option[Long]],
dataField: Data)
-class ScalaReflectionRelationSuite extends FunSuite {
+class ScalaReflectionRelationSuite extends SparkFunSuite {
import org.apache.spark.sql.test.TestSQLContext.implicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index 6f6d3c9c243d4..1e8cde606b67b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.test.TestSQLContext
-class SerializationSuite extends FunSuite {
+class SerializationSuite extends SparkFunSuite {
test("[SPARK-5235] SQLContext should be serializable") {
val sqlContext = new SQLContext(TestSQLContext.sparkContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 7cefcf44061ce..339e719f39f16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.columnar
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.types._
-class ColumnStatsSuite extends FunSuite {
+class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 061efb37a0ac3..a1e76eaa982cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -23,15 +23,14 @@ import java.sql.Timestamp
import com.esotericsoftware.kryo.{Serializer, Kryo}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.serializer.KryoRegistrator
-import org.scalatest.FunSuite
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
-class ColumnTypeSuite extends FunSuite with Logging {
+class ColumnTypeSuite extends SparkFunSuite with Logging {
val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index a0702144f942c..2a6e0c376551a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.DataType
@@ -39,7 +38,7 @@ object TestNullableColumnAccessor {
}
}
-class NullableColumnAccessorSuite extends FunSuite {
+class NullableColumnAccessorSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 3a5605d2335d7..cb4e9f1eb7f46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.columnar
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
@@ -35,7 +34,7 @@ object TestNullableColumnBuilder {
}
}
-class NullableColumnBuilderSuite extends FunSuite {
+class NullableColumnBuilderSuite extends SparkFunSuite {
import ColumnarTestUtils._
Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 2a0b701cad7fa..cda1b0992e36f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.columnar
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
-class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
+class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 8b518f094174c..20d65a74e3b7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
-class BooleanBitSetSuite extends FunSuite {
+class BooleanBitSetSuite extends SparkFunSuite {
import BooleanBitSet._
def skeleton(count: Int) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index cef60ec204faa..acfab6586c0d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.columnar.compression
import java.nio.ByteBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.AtomicType
-class DictionaryEncodingSuite extends FunSuite {
+class DictionaryEncodingSuite extends SparkFunSuite {
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 5514590541dd6..2111e9fbe62cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.IntegralType
-class IntegralDeltaSuite extends FunSuite {
+class IntegralDeltaSuite extends SparkFunSuite {
testIntegralDelta(new IntColumnStats, INT, IntDelta)
testIntegralDelta(new LongColumnStats, LONG, LongDelta)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 6ee48f6291914..67ec08f594a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.columnar.compression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types.AtomicType
-class RunLengthEncodingSuite extends FunSuite {
+class RunLengthEncodingSuite extends SparkFunSuite {
testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
testRunLengthEncoding(new ByteColumnStats, BYTE)
testRunLengthEncoding(new ShortColumnStats, SHORT)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 523be56df65ba..45a7e8fe68f72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SQLConf, execution}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
@@ -31,7 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
-class PlannerSuite extends FunSuite {
+class PlannerSuite extends SparkFunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 15337c4045436..6ca5390cde23e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -19,17 +19,17 @@ package org.apache.spark.sql.execution
import java.sql.{Timestamp, Date}
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{ShuffleDependency, SparkFunSuite}
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
-class SparkSqlSerializer2DataTypeSuite extends FunSuite {
+class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
// Make sure that we will not use serializer2 for unsupported data types.
def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
val testName =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 358d8cf06e463..8ec3985e00360 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.execution.debug
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
-class DebuggingSuite extends FunSuite {
+class DebuggingSuite extends SparkFunSuite {
test("DataFrame.debug()") {
testData.debug()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 2aad01ded1acf..5290c28cfca02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution.joins
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
import org.apache.spark.util.collection.CompactBuffer
-class HashedRelationSuite extends FunSuite {
+class HashedRelationSuite extends SparkFunSuite {
// Key is simply the record itself
private val keyProjection = new Projection {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 30279f528944b..e20c66cb2f1d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -21,14 +21,15 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
import org.h2.jdbc.JdbcSQLException
-import org.scalatest.{FunSuite, BeforeAndAfter}
+import org.scalatest.BeforeAndAfter
import TestSQLContext._
import TestSQLContext.implicits._
-class JDBCSuite extends FunSuite with BeforeAndAfter {
+class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
@@ -67,7 +68,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
-
+
sql(
s"""
|CREATE TEMPORARY TABLE fetchtwo
@@ -75,7 +76,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
| fetchSize '2')
""".stripMargin.replaceAll("\n", " "))
-
+
sql(
s"""
|CREATE TEMPORARY TABLE parts
@@ -208,7 +209,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(ids(1) === 2)
assert(ids(2) === 3)
}
-
+
test("SELECT second field when fetchSize is two") {
val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _)
assert(ids.size === 3)
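The fetchtwo table above passes fetchSize through the data source's OPTIONS clause; the same map can be supplied programmatically through DataFrameReader. A sketch reusing this suite's H2 URL:

    import org.apache.spark.sql.SQLContext

    // `sqlContext` is assumed to be an existing SQLContext (e.g. TestSQLContext).
    def loadPeople(sqlContext: SQLContext) =
      sqlContext.read
        .format("jdbc")
        .options(Map(
          "url" -> "jdbc:h2:mem:testdb0;user=testUser;password=testPass",
          "dbtable" -> "TEST.PEOPLE",
          "fetchSize" -> "2")) // same hint the OPTIONS clause passes to the JDBC source
        .load()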
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 2e4c12f9da80c..2de8c1a6098e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
-class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
+class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
val url1 = "jdbc:h2:mem:testdb3"
@@ -35,12 +36,12 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
properties.setProperty("user", "testUser")
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
-
+
before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
conn.prepareStatement("create schema test").executeUpdate()
-
+
conn1 = DriverManager.getConnection(url1, properties)
conn1.prepareStatement("create schema test").executeUpdate()
conn1.prepareStatement("drop table if exists test.people").executeUpdate()
@@ -52,20 +53,20 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
conn1.prepareStatement(
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
-
+
TestSQLContext.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
-
+
TestSQLContext.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass')
- """.stripMargin.replaceAll("\n", " "))
+ """.stripMargin.replaceAll("\n", " "))
}
after {
@@ -151,5 +152,5 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
- }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index f231589e9674d..3b29979452ad9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -18,10 +18,11 @@ package org.apache.spark.sql.parquet
import java.io.File
import java.math.BigInteger
-import java.sql.{Timestamp, Date}
+import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer
+import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -432,4 +433,20 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
checkAnswer(read.load(dir.toString).select(fields: _*), row)
}
}
+
+ test("SPARK-8037: Ignores files whose name starts with dot") {
+ withTempPath { dir =>
+ val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d")
+
+ df.write
+ .format("parquet")
+ .partitionBy("b", "c", "d")
+ .save(dir.getCanonicalPath)
+
+ Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
+ Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
+
+ checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df)
+ }
+ }
}
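The SPARK-8037 test plants a .DS_Store file and a .foo directory inside the partition tree and expects discovery to skip them. The rule it relies on amounts to a simple name predicate; a sketch (the helper name is hypothetical):

    import org.apache.hadoop.fs.Path

    // Hypothetical predicate mirroring the rule under test: hidden and metadata
    // entries (".DS_Store", ".foo", "_temporary") are not treated as data paths.
    def isDataPath(path: Path): Boolean = {
      val name = path.getName
      !name.startsWith(".") && !name.startsWith("_")
    }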
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index b98ba09ccfc2d..304936fb2be8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
@@ -111,6 +112,18 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
List(Row("same", "run_5", 100)))
}
}
+
+ test("SPARK-6917 DecimalType should work with non-native types") {
+ val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i)))
+ val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
+ StructField("time", TimestampType, false)).toArray)
+ withTempPath { file =>
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
+ df.write.parquet(file.getCanonicalPath)
+ val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+ checkAnswer(df2, df.collect().toSeq)
+ }
+ }
}
class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
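For reference on the SPARK-6917 test: DecimalType(18, 0) declares precision 18 and scale 0, and Decimal(i, 18, 0) builds a value carried at that precision. A small sketch, assuming the Decimal(unscaled, precision, scale) factory:

    import org.apache.spark.sql.types.{Decimal, DecimalType}

    val dt = DecimalType(18, 0) // up to 18 significant digits, none after the point
    val d = Decimal(7, 18, 0)   // the value 7 at precision 18, scale 0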
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index c964b6d984557..caec2a6f25489 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.parquet
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.scalatest.FunSuite
import parquet.schema.MessageTypeParser
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-class ParquetSchemaSuite extends FunSuite with ParquetTest {
+class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
val sqlContext = TestSQLContext
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
index 8331a14c9295c..296b0d6f74a0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.sources
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class ResolvedDataSourceSuite extends FunSuite {
+class ResolvedDataSourceSuite extends SparkFunSuite {
test("builtin sources") {
assert(ResolvedDataSource.lookupDataSource("jdbc") ===
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index 437f697d25bf3..20d3c7d4c5959 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -41,6 +41,13 @@
      <artifactId>spark-hive_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 3458b04bfba0f..94687eeda4179 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -17,23 +17,23 @@
package org.apache.spark.sql.hive.thriftserver
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor}
-import org.apache.spark.sql.SQLConf
-import org.apache.spark.{SparkContext, SparkConf, Logging}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart}
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
-import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener}
import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab
+import org.apache.spark.sql.hive.{HiveContext, HiveShim}
import org.apache.spark.util.Utils
-
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.{Logging, SparkContext}
/**
* The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a
@@ -51,6 +51,7 @@ object HiveThriftServer2 extends Logging {
@DeveloperApi
def startWithContext(sqlContext: HiveContext): Unit = {
val server = new HiveThriftServer2(sqlContext)
+ sqlContext.setConf("spark.sql.hive.version", HiveShim.version)
server.init(sqlContext.hiveconf)
server.start()
listener = new HiveThriftServer2Listener(server, sqlContext.conf)
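With the extra setConf call, a programmatically started server now advertises the Hive version it was built against, like the script-launched server. A sketch of typical embedding, assuming an existing SparkContext:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

    object EmbeddedServer {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("embedded-thriftserver"))
        val hiveContext = new HiveContext(sc)
        // The server shares this context, so its temporary tables are
        // visible to JDBC clients that connect to it.
        HiveThriftServer2.startWithContext(hiveContext)
      }
    }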
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index cc07db827d359..3732af7870b93 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -25,16 +25,16 @@ import scala.concurrent.{Await, Promise}
import scala.sys.process.{Process, ProcessLogger}
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.util.Utils
/**
* A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary
* Hive metastore and warehouse.
*/
-class CliSuite extends FunSuite with BeforeAndAfter with Logging {
+class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val warehousePath = Utils.createTempDir()
val metastorePath = Utils.createTempDir()
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 1fadea97fd07f..a93a3dee43511 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -27,6 +27,8 @@ import scala.concurrent.{Await, Promise}
import scala.sys.process.{Process, ProcessLogger}
import scala.util.{Random, Try}
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.Files
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.apache.hive.service.auth.PlainSaslHelper
@@ -35,9 +37,9 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client
import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.transport.TSocket
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.util.Utils
@@ -54,7 +56,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
override def mode: ServerMode.Value = ServerMode.binary
private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
- // Transport creation logics below mimics HiveConnection.createBinaryTransport
+ // Transport creation logic below mimics HiveConnection.createBinaryTransport
val rawTransport = new TSocket("localhost", serverPort)
val user = System.getProperty("user.name")
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
@@ -391,10 +393,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
val statements = connections.map(_.createStatement())
try {
- statements.zip(fs).map { case (s, f) => f(s) }
+ statements.zip(fs).foreach { case (s, f) => f(s) }
} finally {
- statements.map(_.close())
- connections.map(_.close())
+ statements.foreach(_.close())
+ connections.foreach(_.close())
}
}
@@ -403,7 +405,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
}
}
-abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging {
def mode: ServerMode.Value
private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
@@ -433,15 +435,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit
ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
}
+ val driverClassPath = {
+ // Writes a temporary log4j.properties and prepends it to the driver classpath, so that it
+ // overrides all other potential log4j configurations contained in other dependency jar files.
+ val tempLog4jConf = Utils.createTempDir().getCanonicalPath
+
+ Files.write(
+ """log4j.rootCategory=INFO, console
+ |log4j.appender.console=org.apache.log4j.ConsoleAppender
+ |log4j.appender.console.target=System.err
+ |log4j.appender.console.layout=org.apache.log4j.PatternLayout
+ |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+ """.stripMargin,
+ new File(s"$tempLog4jConf/log4j.properties"),
+ UTF_8)
+
+ tempLog4jConf + File.pathSeparator + sys.props("java.class.path")
+ }
+
s"""$startScript
| --master local
- | --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
| --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
| --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
| --hiveconf $portConf=$port
- | --driver-class-path ${sys.props("java.class.path")}
+ | --driver-class-path $driverClassPath
+ | --driver-java-options -Dlog4j.debug
| --conf spark.ui.enabled=false
""".stripMargin.split("\\s+").toSeq
}
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 615b07e74d535..923ffabb9b99e 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -41,6 +41,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 47b85731587d5..ca1f49b546bd7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -596,7 +596,7 @@ private[hive] case class MetastoreRelation
self: Product =>
- override def equals(other: scala.Any): Boolean = other match {
+ override def equals(other: Any): Boolean = other match {
case relation: MetastoreRelation =>
databaseName == relation.databaseName &&
tableName == relation.tableName &&
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 3915ee835685f..a5ca3613c5e00 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan {
override def output: Seq[Attribute] = Seq.empty
}
-case class CreateTableAsSelect(
+private[hive] case class CreateTableAsSelect(
tableDesc: HiveTable,
child: LogicalPlan,
allowExisting: Boolean) extends UnaryNode with Command {
@@ -1561,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
""".stripMargin)
}
+ /* Case-insensitive matchers for the window specification keywords */
+ val PRECEDING = "(?i)preceding".r
+ val FOLLOWING = "(?i)following".r
+ val CURRENT = "(?i)current".r
def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
case Token(windowName, Nil) :: Nil =>
// Refer to a window spec defined in the window clause.
@@ -1614,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
} else {
val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
def nodeToBoundary(node: Node): FrameBoundary = node match {
- case Token("preceding", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt)
- case Token("following", Token(count, Nil) :: Nil) =>
- if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt)
- case Token("current", Nil) => CurrentRow
+ case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedPreceding
+ } else {
+ ValuePreceding(count.toInt)
+ }
+ case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
+ if (count.toLowerCase() == "unbounded") {
+ UnboundedFollowing
+ } else {
+ ValueFollowing(count.toInt)
+ }
+ case Token(CURRENT(), Nil) => CurrentRow
case _ =>
throw new NotImplementedError(
s"""No parse rules for the Window Frame Boundary based on Node ${node.getName}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 196a3d836cab2..16851fdd71a98 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -90,14 +90,14 @@ private[hive] object IsolatedClientLoader {
* `ClientInterface`, unless `isolationOn` is set to `false`.
*
* @param version The version of hive on the classpath. used to pick specific function signatures
- * that are not compatibile accross versions.
+ * that are not compatible across versions.
* @param execJars A collection of jar files that must include hive and hadoop.
* @param config A set of options that will be added to the HiveConf of the constructed client.
* @param isolationOn When true, custom versions of barrier classes will be constructed. Must be
* true unless loading the version of hive that is on Sparks classloader.
- * @param rootClassLoader The system root classloader. Must not know about hive classes.
- * @param baseClassLoader The spark classloader that is used to load shared classes.
- *
+ * @param rootClassLoader The system root classloader.
+ * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know
+ * about Hive classes.
*/
private[hive] class IsolatedClientLoader(
val version: HiveVersion,
@@ -110,7 +110,7 @@ private[hive] class IsolatedClientLoader(
val barrierPrefixes: Seq[String] = Seq.empty)
extends Logging {
- // Check to make sure that the root classloader does not know about Hive.
+ // Check to make sure that the base classloader does not know about Hive.
assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure)
/** All jars used by the hive specific classloader. */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala
index c600b158c5460..4d053ae42c2ea 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala
@@ -30,7 +30,7 @@ private[client] object ReflectionException {
/**
* Provides implicit functions on any object for calling methods reflectively.
*/
-protected trait ReflectionMagic {
+private[client] trait ReflectionMagic {
/** code for InstanceMagic
println(
(1 to 22).map { n =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
index 7db9200d47440..410d9881ac214 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
@@ -29,5 +29,5 @@ package object client {
case object v13 extends HiveVersion("0.13.1", false)
}
// scalastyle:on
-
+
}
\ No newline at end of file
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 62dc4167b78dd..11ee5503146b9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -63,7 +63,7 @@ case class HiveTableScan(
BindReferences.bindReference(pred, relation.partitionKeys)
}
- // Create a local copy of hiveconf,so that scan specific modifications should not impact
+ // Create a local copy of hiveconf, so that scan-specific modifications do not impact
// other queries
@transient
private[this] val hiveExtraConf = new HiveConf(context.hiveconf)
@@ -72,7 +72,7 @@ case class HiveTableScan(
addColumnMetadataToConf(hiveExtraConf)
@transient
- private[this] val hadoopReader =
+ private[this] val hadoopReader =
new HadoopTableReader(attributes, relation, context, hiveExtraConf)
private[this] def castFromString(value: String, dataType: DataType) = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 6f27a8626fc1e..fd623370cc407 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -62,7 +62,7 @@ case class ScriptTransformation(
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val reader = new BufferedReader(new InputStreamReader(inputStream))
-
+
val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors {
@@ -95,7 +95,7 @@ case class ScriptTransformation(
val raw = outputSerde.deserialize(writable)
val dataList = outputSoi.getStructFieldsDataAsList(raw)
val fieldList = outputSoi.getAllStructFieldRefs()
-
+
var i = 0
dataList.foreach( element => {
if (element == null) {
@@ -117,7 +117,7 @@ case class ScriptTransformation(
if (!hasNext) {
throw new NoSuchElementException
}
-
+
if (outputSerde == null) {
val prevLine = curLine
curLine = reader.readLine()
@@ -192,7 +192,7 @@ case class HiveScriptIOSchema (
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-
+
def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
val (columns, columnTypes) = parseAttrs(input)
val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
@@ -206,13 +206,13 @@ case class HiveScriptIOSchema (
}
def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+
val columns = attrs.map {
case aref: AttributeReference => aref.name
case e: NamedExpression => e.name
case _ => null
}
-
+
val columnTypes = attrs.map {
case aref: AttributeReference => aref.dataType
case e: NamedExpression => e.dataType
@@ -221,7 +221,7 @@ case class HiveScriptIOSchema (
(columns, columnTypes)
}
-
+
def initSerDe(serdeClassName: String, columns: Seq[String],
columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
@@ -240,7 +240,7 @@ case class HiveScriptIOSchema (
(kv._1.split("'")(1), kv._2.split("'")(1))
}).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
-
+
val properties = new Properties()
properties.putAll(propsMap)
serde.initialize(null, properties)
@@ -261,7 +261,7 @@ case class HiveScriptIOSchema (
null
}
}
-
+
def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
if (outputSerde != null) {
outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index bb116e3ab7de7..1658bb93b0b79 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
type UDFType = UDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
@@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF
+ override def deterministic: Boolean = isUDFDeterministic
+
override def nullable: Boolean = true
@transient
@@ -555,12 +559,12 @@ private[hive] case class HiveUdafFunction(
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}
-
+
private val inspectors = exprs.map(toInspector).toArray
-
- private val function = {
+
+ private val function = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
- resolver.getEvaluator(parameterInfo)
+ resolver.getEvaluator(parameterInfo)
}
private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
@@ -575,7 +579,7 @@ private[hive] case class HiveUdafFunction(
@transient
protected lazy val cached = new Array[AnyRef](exprs.length)
-
+
def update(input: Row): Unit = {
val inputs = inputProjection(input)
function.iterate(buffer, wrap(inputs, inspectors, cached))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 945596db80326..39d315aaeab57 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest {
checkAnswer(
sql("SELECT * FROM src s"),
preCacheResults)
-
+
uncacheTable("src")
assertCached(sql("SELECT * FROM src"), 0)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 80c2d32bf70d7..df137e7b2b333 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -26,12 +26,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.LongWritable
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Literal, Row}
import org.apache.spark.sql.types._
-class HiveInspectorSuite extends FunSuite with HiveInspectors {
+class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
test("Test wrap SettableStructObjectInspector") {
val udaf = new UDAFPercentile.PercentileLongEvaluator()
udaf.init()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index fa8e11ffec2b4..e9bb32667936c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.hive
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.hive.test.TestHive
-import org.scalatest.FunSuite
import org.apache.spark.sql.test.ExamplePointUDT
import org.apache.spark.sql.types.StructType
-class HiveMetastoreCatalogSuite extends FunSuite {
+class HiveMetastoreCatalogSuite extends SparkFunSuite {
test("struct field should accept underscore in sub-column name") {
val metastr = "struct"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 941a2941649b8..f765395e148af 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.hive
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class HiveQlSuite extends FunSuite with BeforeAndAfterAll {
+class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
override def beforeAll() {
if (SessionState.get() == null) {
SessionState.start(new HiveConf())
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 9cc4685499f19..aa5dbe2db6903 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
testData.collect().toSeq
)
-
+
// test difference type of field
sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT")
checkAnswer(sql("select key,value from table_with_partition where ds='1' "),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
index 8afe5459d4f1b..a492ecf203d17 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.sql.hive
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.hive.test.TestHive
-class SerializationSuite extends FunSuite {
+class SerializationSuite extends SparkFunSuite {
test("[SPARK-5840] HiveContext should be serializable") {
val hiveContext = TestHive
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 321dc8d7322b8..7eb4842726665 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -17,18 +17,17 @@
package org.apache.spark.sql.hive.client
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.util.Utils
-import org.scalatest.FunSuite
/**
- * A simple set of tests that call the methods of a hive ClientInterface, loading different version
- * of hive from maven central. These tests are simple in that they are mostly just testing to make
- * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity
+ * A simple set of tests that call the methods of a hive ClientInterface, loading different versions
+ * of hive from maven central. These tests are simple in that they are mostly just testing to make
+ * sure that reflective calls are not throwing NoSuchMethod errors, but the actual functionality
* is not fully tested.
*/
-class VersionsSuite extends FunSuite with Logging {
+class VersionsSuite extends SparkFunSuite with Logging {
private def buildConf() = {
lazy val warehousePath = Utils.createTempDir()
lazy val metastorePath = Utils.createTempDir()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
index 23ece7e7cf6e9..b0d3dd44daedc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.hive.test.TestHiveContext
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll {
+class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll {
ignore("multiple instances not supported") {
test("Multiple Hive Instances") {
(1 to 10).map { i =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 55e5551b63818..c9dd4c0935a72 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution
import java.io._
-import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+import org.scalatest.{BeforeAndAfterAll, GivenWhenThen}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.sources.DescribeCommand
import org.apache.spark.sql.execution.{SetCommand, ExplainCommand}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive
* configured using system properties.
*/
abstract class HiveComparisonTest
- extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
+ extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging {
/**
* When set, any cache files that result in test failures will be deleted. Used when the test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4af31d482ce42..440b7c87b0da2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -57,7 +57,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
sql(
"""
- |CREATE TEMPORARY FUNCTION udtf_count2
+ |CREATE TEMPORARY FUNCTION udtf_count2
|AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
""".stripMargin)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 0ba4d11478211..2209fc2f30a3c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest {
TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect()
TestHive.sql("drop table tb")
}
-
+
test("Spark-4077: timestamp query for null value") {
TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null")
TestHive.sql(
@@ -71,11 +71,11 @@ class HiveTableScanSuite extends HiveComparisonTest {
FIELDS TERMINATED BY ','
LINES TERMINATED BY '\n'
""".stripMargin)
- val location =
+ val location =
Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile()
-
+
TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null")
- assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect()
+ assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect()
=== Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null)))
TestHive.sql("DROP TABLE timestamp_query_null")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 7f49eac490572..ce5985888f540 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
TestHive.reset()
}
-
+
test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 27863a60145d7..aba3becb1bce2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -780,6 +780,42 @@ class SQLQuerySuite extends QueryTest {
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+ test("window function: multiple window expressions in a single expression") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.registerTempTable("nums")
+
+ val expected =
+ Row(1, 1, 1, 55, 1, 57) ::
+ Row(0, 2, 3, 55, 2, 60) ::
+ Row(1, 3, 6, 55, 4, 65) ::
+ Row(0, 4, 10, 55, 6, 71) ::
+ Row(1, 5, 15, 55, 9, 79) ::
+ Row(0, 6, 21, 55, 12, 88) ::
+ Row(1, 7, 28, 55, 16, 99) ::
+ Row(0, 8, 36, 55, 20, 111) ::
+ Row(1, 9, 45, 55, 25, 125) ::
+ Row(0, 10, 55, 55, 30, 140) :: Nil
+
+ val actual = sql(
+ """
+ |SELECT
+ | y,
+ | x,
+ | sum(x) OVER w1 AS running_sum,
+ | sum(x) OVER w2 AS total_sum,
+ | sum(x) OVER w3 AS running_sum_per_y,
+ | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2
+ |FROM nums
+ |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW),
+ | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING),
+ | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+ """.stripMargin)
+
+ checkAnswer(actual, expected)
+
+ dropTempTable("nums")
+ }
+
test("test case key when") {
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
checkAnswer(
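To see how the expected rows in the new window-function test are derived: total_sum over w2 is the full sum 1 + 2 + ... + 10 = 55 for every row. For the first row (x = 1, y = 1), running_sum over w1 is 1 and running_sum_per_y over w3 (the odd-x partition) is also 1, so combined2 = 1 + 55 + 1 = 57, matching Row(1, 1, 1, 55, 1, 57). The deliberately mixed-case frame keywords (UnBOUNDED PRECEDiNG, CuRRENT RoW) exercise the case-insensitive boundary parsing added to HiveQl.scala earlier in this patch.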
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
index 88c99e35260d9..0e63d84e9824a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.hive.orc
import java.io.File
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.util.Utils
-import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
@@ -38,7 +39,7 @@ case class OrcParData(intField: Int, stringField: String)
case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot
-class OrcPartitionDiscoverySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
+class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll {
val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal
def withTempDir(f: File => Unit): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index cdd6e705f4a2c..57c23fe77f8b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -21,8 +21,9 @@ import java.io.File
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.io.orc.CompressionKind
-import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
+import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.hive.test.TestHive
@@ -50,7 +51,7 @@ case class Contact(name: String, phone: String)
case class Person(name: String, age: Int, contacts: Seq[Contact])
-class OrcQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll with OrcTest {
+class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
override val sqlContext = TestHive
import TestHive.read
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index cf5ae88dc4bee..74095426741e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.sources
+import java.io.File
+
+import com.google.common.io.Files
import org.apache.hadoop.fs.Path
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.hive.test.TestHive
@@ -454,6 +456,20 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
}
}
+
+ test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
+ val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+
+ df.write
+ .format(dataSourceName)
+ .mode(SaveMode.Overwrite)
+ .partitionBy("c", "a")
+ .saveAsTable("t")
+
+ withTable("t") {
+ checkAnswer(table("t"), df.select('b, 'c, 'a).collect())
+ }
+ }
}
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -485,7 +501,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
-class CommitFailureTestRelationSuite extends FunSuite with SQLTestUtils {
+class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
import TestHive.implicits._
override val sqlContext = TestHive
@@ -535,20 +551,6 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
- test("SPARK-7616: adjust column name order accordingly when saving partitioned table") {
- val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
-
- df.write
- .format("parquet")
- .mode(SaveMode.Overwrite)
- .partitionBy("c", "a")
- .saveAsTable("t")
-
- withTable("t") {
- checkAnswer(table("t"), df.select('b, 'c, 'a).collect())
- }
- }
-
test("SPARK-7868: _temporary directories should be ignored") {
withTempPath { dir =>
val df = Seq("a", "b", "c").zipWithIndex.toDF()
@@ -564,4 +566,32 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect())
}
}
+
+ test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") {
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ val df = Seq(1 -> "a").toDF()
+
+ // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw
+ // since it's not a valid Parquet file.
+ val emptyFile = new File(path, "empty")
+ Files.createParentDirs(emptyFile)
+ Files.touch(emptyFile)
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Ignore).save(path)
+
+ // This should only complain that the destination directory already exists, rather than that
+ // file "empty" is not a Parquet file.
+ assert {
+ intercept[RuntimeException] {
+ df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path)
+ }.getMessage.contains("already exists")
+ }
+
+ // This shouldn't throw anything.
+ df.write.format("parquet").mode(SaveMode.Overwrite).save(path)
+ checkAnswer(read.format("parquet").load(path), df)
+ }
+ }
}
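The SPARK-8014 test pins down how each SaveMode treats a pre-existing, non-empty output directory. A sketch of the observable behavior (not of the implementation):

    import org.apache.spark.sql.{DataFrame, SaveMode}

    def demoSaveModes(df: DataFrame, path: String): Unit = {
      // `path` is assumed to already exist and contain files.
      df.write.format("parquet").mode(SaveMode.Ignore).save(path)    // silently does nothing
      // SaveMode.ErrorIfExists would fail fast here with "already exists"
      df.write.format("parquet").mode(SaveMode.Overwrite).save(path) // replaces the existing contents
      df.write.format("parquet").mode(SaveMode.Append).save(path)    // the only mode that must scan existing files
    }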
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 5ab7f4472c38b..49d035a1e9696 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -40,6 +40,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 25842d502543e..9cd9684d36404 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
import scala.collection.mutable.Queue
import scala.reflect.ClassTag
+import scala.util.control.NonFatal
import akka.actor.{Props, SupervisorStrategy}
import org.apache.hadoop.conf.Configuration
@@ -270,6 +271,8 @@ class StreamingContext private[streaming] (
* Create an input stream with any arbitrary user implemented receiver.
* Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html
* @param receiver Custom implementation of Receiver
+ *
+ * @deprecated As of 1.0.0", replaced by `receiverStream`.
*/
@deprecated("Use receiverStream", "1.0.0")
def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = {
@@ -576,18 +579,26 @@ class StreamingContext private[streaming] (
def start(): Unit = synchronized {
state match {
case INITIALIZED =>
- validate()
startSite.set(DStream.getCreationSite())
sparkContext.setCallSite(startSite.get)
StreamingContext.ACTIVATION_LOCK.synchronized {
StreamingContext.assertNoOtherContextIsActive()
- scheduler.start()
- uiTab.foreach(_.attach())
- state = StreamingContextState.ACTIVE
+ try {
+ validate()
+ scheduler.start()
+ state = StreamingContextState.ACTIVE
+ } catch {
+ case NonFatal(e) =>
+ logError("Error starting the context, marking it as stopped", e)
+ scheduler.stop(false)
+ state = StreamingContextState.STOPPED
+ throw e
+ }
StreamingContext.setActiveContext(this)
}
shutdownHookRef = Utils.addShutdownHook(
StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown)
+ uiTab.foreach(_.attach())
logInfo("StreamingContext started")
case ACTIVE =>
logWarning("StreamingContext has already been started")
@@ -608,6 +619,8 @@ class StreamingContext private[streaming] (
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
* @param timeout time to wait in milliseconds
+ *
+ * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`.
*/
@deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0")
def awaitTermination(timeout: Long) {
@@ -732,6 +745,10 @@ object StreamingContext extends Logging {
}
}
+ /**
+ * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object.
+ * This is kept here only for backward compatibility.
+ */
@deprecated("Replaced by implicit functions in the DStream companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
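The reworked start() above uses a common Scala idiom: catch NonFatal so that fatal JVM errors (OutOfMemoryError and friends) propagate untouched, while ordinary failures trigger cleanup before the original exception is rethrown. A minimal standalone sketch:

    import scala.util.control.NonFatal

    def startGuarded(start: () => Unit, cleanup: () => Unit): Unit = {
      try {
        start()
      } catch {
        case NonFatal(e) =>
          cleanup() // roll back partial startup
          throw e   // the caller still sees the original failure
      }
    }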
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index b639b94d5ca47..989e3a729ebc2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
/** The underlying SparkContext */
val sparkContext = new JavaSparkContext(ssc.sc)
+ /**
+ * @deprecated As of 0.9.0, replaced by `sparkContext`
+ */
@deprecated("use sparkContext", "0.9.0")
val sc: JavaSparkContext = sparkContext
@@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
* @param timeout time to wait in milliseconds
+ * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`.
*/
@deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0")
def awaitTermination(timeout: Long): Unit = {
@@ -677,6 +681,7 @@ object JavaStreamingContext {
*
* @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
* @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0")
def getOrCreate(
@@ -699,6 +704,7 @@ object JavaStreamingContext {
* @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
* @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
* file system
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
@@ -724,6 +730,7 @@ object JavaStreamingContext {
* file system
* @param createOnError Whether to create a new JavaStreamingContext if there is an
* error in reading checkpoint data.
+ * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory.
*/
@deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
def getOrCreate(
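The additions in this file pair the @deprecated annotation, which drives compiler warnings, with a scaladoc @deprecated tag, which makes the replacement explicit in the generated API docs. A minimal sketch of the pattern, with hypothetical names:

    /**
     * Old entry point, kept for source compatibility.
     *
     * @deprecated As of 1.4.0, replaced by `newEntryPoint`.
     */
    @deprecated("use newEntryPoint", "1.4.0")
    def oldEntryPoint(): Unit = newEntryPoint()

    def newEntryPoint(): Unit = { /* ... */ }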
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 6efcc193bfccc..192aa6a139bcb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -603,6 +603,8 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of 0.9.0, replaced by `foreachRDD`.
*/
@deprecated("use foreachRDD", "0.9.0")
def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope {
@@ -612,6 +614,8 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of 0.9.0, replaced by `foreachRDD`.
*/
@deprecated("use foreachRDD", "0.9.0")
def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope {
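
`foreachRDD` is behaviorally identical to the deprecated `foreach`; the rename avoids confusion with per-element `foreach` on collections and RDDs. A small sketch, assuming a `DStream[String]`:

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.DStream

def logBatchSizes(stream: DStream[String]): Unit =
  stream.foreachRDD { (rdd: RDD[String]) =>
    // Runs once per batch interval, on the driver, with that batch's RDD.
    println(s"batch contains ${rdd.count()} records")
  }
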
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 0588517a2de39..8d73593ab6375 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -191,7 +191,7 @@ private[streaming] class BlockGenerator(
logError(message, t)
listener.onError(message, t)
}
-
+
private def pushBlock(block: Block) {
listener.onPushBlock(block.id, block.buffer)
logInfo("Pushed block " + block.id)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 651b534ac1900..207d64d9414ee 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI
private[streaming] class BlockManagerBasedBlockHandler(
blockManager: BlockManager, storageLevel: StorageLevel)
extends ReceivedBlockHandler with Logging {
-
+
def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = {
val putResult: Seq[(BlockId, BlockStatus)] = block match {
case ArrayBufferBlock(arrayBuffer) =>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 1d1ddaaccf217..4af9b6d3b56ab 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
eventLoop.post(ErrorReported(msg, e))
}
+ def isStarted(): Boolean = synchronized {
+ eventLoop != null
+ }
+
private def processEvent(event: JobSchedulerEvent) {
try {
event match {
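
`eventLoop` is assigned in `start()` and cleared in `stop()` under the scheduler's lock, so a null check under the same lock gives a consistent started-flag; the accessor exists so the new `StreamingContextSuite` test later in this diff can assert that a failed start left the scheduler stopped. The idiom in isolation (names illustrative):

class Lifecycle {
  private var resource: AnyRef = null
  def start(): Unit = synchronized { resource = new Object }
  def stop(): Unit = synchronized { resource = null }
  // Must read under the same lock that start()/stop() write under.
  def isStarted(): Boolean = synchronized { resource != null }
}
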
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 6a1dd6949b204..9b5e4dc819a2b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.streaming
import java.io.NotSerializableException
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.{HashPartitioner, SparkContext, SparkException}
+import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.ReturnStatementInClosureException
@@ -29,7 +29,7 @@ import org.apache.spark.util.ReturnStatementInClosureException
/**
* Test that closures passed to DStream operations are actually cleaned.
*/
-class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll {
+class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
private var ssc: StreamingContext = null
override def beforeAll(): Unit = {
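
This and the following test hunks all swap ScalaTest's `FunSuite` for a shared `SparkFunSuite` base class pulled in from spark-core's test sources (see the test-jar dependency added to yarn/pom.xml later in this diff). A sketch of what such a base suite plausibly provides; the real class lives in core, so the details here are assumptions:

import org.apache.spark.Logging
import org.scalatest.{FunSuite, Outcome}

private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
  // Log where each test begins and ends, so interleaved output in
  // unit-tests.log can be attributed to a suite and test name.
  protected override def withFixture(test: NoArgTest): Outcome = {
    val suite = this.getClass.getName.replaceAll("org\\.apache\\.spark", "o.a.s")
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $suite: '${test.text}' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $suite: '${test.text}' =====\n")
    }
  }
}
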
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
index e3fb2ef130859..8844c9d74b933 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.streaming
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDDOperationScope
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.ui.UIUtils
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.ui.UIUtils
/**
* Tests whether scope information is passed from DStream operations to RDDs correctly.
*/
-class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll {
+class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
private var ssc: StreamingContext = null
private val batchDuration: Duration = Seconds(1)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 23804237bda80..cca8cedb1d080 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -25,7 +25,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
@@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils}
import WriteAheadLogBasedBlockHandler._
import WriteAheadLogSuite._
-class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+class ReceivedBlockHandlerSuite
+ extends SparkFunSuite
+ with BeforeAndAfter
+ with Matchers
+ with Logging {
val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
val hadoopConf = new Configuration()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index b1af8d5eaacfb..6f0ee774cb5cf 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps}
import scala.util.Random
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler._
@@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._
import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
class ReceivedBlockTrackerSuite
- extends FunSuite with BeforeAndAfter with Matchers with Logging {
+ extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
val hadoopConf = new Configuration()
val akkaTimeout = 10 seconds
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index e36c7914b130e..819dd2ccfe915 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -25,16 +25,16 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
-import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
+import org.scalatest.{Assertions, BeforeAndAfter}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite}
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
+class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
@@ -151,6 +151,22 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
assert(StreamingContext.getActive().isEmpty)
}
+ test("start failure should stop internal components") {
+ ssc = new StreamingContext(conf, batchDuration)
+ val inputStream = addInputStream(ssc)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
+ // start() should fail because no checkpoint directory was set
+ intercept[Exception] {
+ ssc.start()
+ }
+ assert(ssc.getState() === StreamingContextState.STOPPED)
+ assert(ssc.scheduler.isStarted === false)
+ }
+
+
test("start multiple times") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register()
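
The new test above leans on the fact that `updateStateByKey` requires checkpointing to be enabled, so `start()` is guaranteed to throw; the assertions then verify that the failed start tore the scheduler down. For contrast, a sketch of the configuration that would let the same pipeline start (path illustrative):

import org.apache.spark.streaming.StreamingContext

def startWithCheckpoint(ssc: StreamingContext): Unit = {
  ssc.checkpoint("/tmp/streaming-checkpoint") // hypothetical directory
  ssc.start()
}
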
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 554cd30223f44..31b1aebf6a8ec 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer
import scala.language.implicitConversions
import scala.reflect.ClassTag
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.scalatest.time.{Span, Seconds => ScalaTestSeconds}
import org.scalatest.concurrent.Eventually.timeout
import org.scalatest.concurrent.PatienceConfiguration
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
import org.apache.spark.streaming.scheduler._
@@ -204,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) {
* This is the base trait for Spark Streaming test suites. It provides basic functionality
* to run a user-defined set of input through user-defined stream operations and verify the output.
*/
-trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
+trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
def framework: String = this.getClass.getSimpleName
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index 441bbf95d0153..cbc24aee4fa1e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -28,14 +28,11 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
-
-
-
/**
* Selenium tests for the Spark Web UI.
*/
class UISeleniumSuite
- extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
+ extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
implicit var webDriver: WebDriver = _
@@ -197,4 +194,3 @@ class UISeleniumSuite
}
}
}
-
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 6859b65c7165f..cb017b798b2a4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -21,15 +21,15 @@ import java.io.File
import scala.util.Random
import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
class WriteAheadLogBackedBlockRDDSuite
- extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+ extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
val conf = new SparkConf()
.setMaster("local[2]")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
index 5478b41845943..2e210397fe7c7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.streaming.scheduler
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Time, Duration, StreamingContext}
-class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter {
+class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
private var ssc: StreamingContext = _
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
index e9ab917ab845c..d3ca2b58f36c2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.streaming.ui
import java.util.TimeZone
import java.util.concurrent.TimeUnit
-import org.scalatest.FunSuite
import org.scalatest.Matchers
-class UIUtilsSuite extends FunSuite with Matchers{
+import org.apache.spark.SparkFunSuite
+
+class UIUtilsSuite extends SparkFunSuite with Matchers {
test("shortTimeUnitString") {
assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
index 9ebf7b484f421..78fc344b00177 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.streaming.util
import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit._
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class RateLimitedOutputStreamSuite extends FunSuite {
+class RateLimitedOutputStreamSuite extends SparkFunSuite {
private def benchmark[U](f: => U): Long = {
val start = System.nanoTime
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 79098bcf4861c..325ff7c74c39d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -28,15 +28,15 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.apache.spark.util.{ManualClock, Utils}
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
-class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
+class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter {
import WriteAheadLogSuite._
-
+
val hadoopConf = new Configuration()
var tempDir: File = null
var testDir: String = null
@@ -359,7 +359,7 @@ object WriteAheadLogSuite {
): FileBasedWriteAheadLog = {
if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
-
+
// Ensure that 500 does not get sorted after 2000, so put a high base value.
data.foreach { item =>
manualClock.advance(500)
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 00d219f836708..e207a46809684 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -39,6 +39,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-yarn-api</artifactId>
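
The test-jar dependency added above is what makes test-only helpers from spark-core, such as the SparkFunSuite base class used throughout the test hunks in this diff, visible on the yarn module's test classpath; the test scope keeps the jar off the compile and runtime classpaths.
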
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
index aaae6f9734a85..77af46c192cc2 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
@@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer(
private val hadoopUtil = YarnSparkHadoopUtil.get
- private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5)
- private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5)
+ private val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
+ private val daysToKeepFiles =
+ sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5)
+ private val numFilesToKeep =
+ sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5)
/**
* Schedule a login from the keytab and principal set using the --principal and --keytab
@@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer(
import scala.concurrent.duration._
try {
val remoteFs = FileSystem.get(hadoopConf)
- val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file"))
+ val credentialsPath = new Path(credentialsFile)
val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis
hadoopUtil.listFilesSorted(
remoteFs, credentialsPath.getParent,
@@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer(
val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab)
logInfo("Successfully logged into KDC.")
val tempCreds = keytabLoggedInUGI.getCredentials
- val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file"))
+ val credentialsPath = new Path(credentialsFile)
val dst = credentialsPath.getParent
keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] {
// Get a copy of the credentials
@@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer(
}
val nextSuffix = lastCredentialsFileSuffix + 1
val tokenPathStr =
- sparkConf.get("spark.yarn.credentials.file") +
- SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix
+ credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix
val tokenPath = new Path(tokenPathStr)
val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
logInfo("Writing out delegation tokens to " + tempTokenPath.toString)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 5653c9f14dc6d..9c7b1b3988082 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
numExecutors = initialNumExecutors
}
+ principal = Option(principal)
+ .orElse(sparkConf.getOption("spark.yarn.principal"))
+ .orNull
+ keytab = Option(keytab)
+ .orElse(sparkConf.getOption("spark.yarn.keytab"))
+ .orNull
}
/**
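
The two new assignments use the same fallback idiom: prefer the command-line value, fall back to the SparkConf entry, and keep null when neither is set, since the surrounding code still treats these fields as nullable Strings. Isolated (the helper name is ours):

import org.apache.spark.SparkConf

def withConfFallback(cliValue: String, key: String, conf: SparkConf): String =
  Option(cliValue).orElse(conf.getOption(key)).orNull
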
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
index 4ca6c903fcf12..3d3a966960e9f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
* Add a resource to the list of distributed cache resources. This list can
* be sent to the ApplicationMaster and possibly the executors so that it can
* be downloaded into the Hadoop distributed cache for use by this application.
- * Adds the LocalResource to the localResources HashMap passed in and saves
+ * Adds the LocalResource to the localResources HashMap passed in and saves
* the stats of the resources so they can be sent to the executors and verified.
*
* @param fs FileSystem
* @param conf Configuration
* @param destPath path to the resource
* @param localResources localResource hashMap to insert the resource into
- * @param resourceType LocalResourceType
+ * @param resourceType LocalResourceType
* @param link link presented in the distributed cache to the destination
- * @param statCache cache to store the file/directory stats
+ * @param statCache cache to store the file/directory stats
* @param appMasterOnly Whether to only add the resource to the app master
*/
def addResource(
fs: FileSystem,
conf: Configuration,
- destPath: Path,
+ destPath: Path,
localResources: HashMap[String, LocalResource],
resourceType: LocalResourceType,
link: String,
@@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
amJarRsrc.setSize(destStatus.getLen())
if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
localResources(link) = amJarRsrc
-
+
if (!appMasterOnly) {
val uri = destPath.toUri()
val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
if (resourceType == LocalResourceType.FILE) {
- distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
destStatus.getModificationTime().toString(), visibility.name())
} else {
- distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
destStatus.getModificationTime().toString(), visibility.name())
}
}
@@ -96,11 +96,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
val (sizes, timeStamps, visibilities) = tupleValues.unzip3
if (keys.size > 0) {
env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
sizes.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
visibilities.reduceLeft[String] { (acc, n) => acc + "," + n }
}
}
@@ -113,11 +113,11 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
val (sizes, timeStamps, visibilities) = tupleValues.unzip3
if (keys.size > 0) {
env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n }
env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
sizes.reduceLeft[String] { (acc, n) => acc + "," + n }
- env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
visibilities.reduceLeft[String] { (acc, n) => acc + "," + n }
}
}
@@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging {
def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
val stat = statCache.get(uri) match {
case Some(existstat) => existstat
- case None =>
+ case None =>
val newStat = fs.getFileStatus(new Path(uri))
statCache.put(uri, newStat)
newStat
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 5e6531895c7ba..68d01c17ef720 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -144,9 +144,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
object YarnSparkHadoopUtil {
- // Additional memory overhead
+ // Additional memory overhead
// 10% was arrived at experimentally. In the interest of minimizing memory waste while covering
- // the common cases. Memory overhead tends to grow with container size.
+ // the common cases. Memory overhead tends to grow with container size.
val MEMORY_OVERHEAD_FACTOR = 0.10
val MEMORY_OVERHEAD_MIN = 384
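
For reference, these constants combine as max(0.10 * containerMemory, 384 MB); the actual call site is elsewhere in the YARN code, so this sketch is illustrative:

def memoryOverheadMb(executorMemoryMb: Int): Int =
  math.max((0.10 * executorMemoryMb).toInt, 384)

// memoryOverheadMb(4096) == 409  (the 10% factor dominates)
// memoryOverheadMb(1024) == 384  (the 384 MB floor dominates)
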
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index 80b57d1355a3a..804dfecde7867 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn
import java.net.URI
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito.when
@@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
import scala.collection.mutable.HashMap
import scala.collection.mutable.Map
+import org.apache.spark.SparkFunSuite
-class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {
class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
- override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
LocalResourceVisibility = {
LocalResourceVisibility.PRIVATE
}
}
-
+
test("test getFileStatus empty") {
val distMgr = new ClientDistributedCacheManager()
val fs = mock[FileSystem]
@@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val distMgr = new ClientDistributedCacheManager()
val fs = mock[FileSystem]
val uri = new URI("/tmp/testing")
- val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
@@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
statCache, false)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
// add another one and verify both there and order correct
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing2"))
val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
statCache, false)
val resource2 = localResources("link2")
assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val env2 = new HashMap[String, String]()
distMgr.setDistFilesEnv(env2)
val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
- val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
@@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
intercept[Exception] {
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
statCache, false)
}
assert(localResources.get("link") === None)
@@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
val localResources = HashMap[String, LocalResource]()
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
statCache, true)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
@@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
val localResources = HashMap[String, LocalResource]()
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
- val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing"))
when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
- distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
statCache, false)
val resource = localResources("link")
assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 6da3e82acdb14..01d33c9ce9297 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.mockito.Matchers._
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterAll, Matchers}
-import org.apache.spark.{SparkException, SparkConf}
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.util.Utils
-class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
+class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
override def beforeAll(): Unit = {
System.setProperty("SPARK_YARN_MODE", "true")
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index b343cbb0c7569..7509000771d94 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
-import org.apache.spark.SecurityManager
+import org.apache.spark.{SecurityManager, SparkFunSuite}
import org.apache.spark.SparkConf
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.YarnAllocator._
import org.apache.spark.scheduler.SplitInfo
-import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterEach, Matchers}
class MockResolver extends DNSToSwitchMapping {
@@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping {
def reloadCachedMappings(names: JList[String]) {}
}
-class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
val conf = new Configuration()
conf.setClass(
CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index dcaeb2e43ff41..d8bc2534c1a6a 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -30,9 +30,9 @@ import com.google.common.io.ByteStreams
import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
-import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfterAll, Matchers}
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
+import org.apache.spark._
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
SparkListenerExecutorAdded}
@@ -43,7 +43,7 @@ import org.apache.spark.util.Utils
* applications, and require the Spark assembly to be built before they can be successfully
* run.
*/
-class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
+class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging {
// log4j configuration for the YARN containers, so that their output is collected
// by YARN instead of trying to overwrite unit-tests.log.
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index e10b985c3c236..49bee0866dd43 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
import org.apache.hadoop.yarn.api.records.ApplicationAccessType
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.util.Utils
-class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging {
+class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging {
val hasBash =
try {