diff --git a/.rat-excludes b/.rat-excludes
index c0f81b57fe09d..994c7e86f8a91 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -80,5 +80,8 @@ local-1425081759269/*
local-1426533911241/*
local-1426633911242/*
local-1430917381534/*
+local-1430917381535_1
+local-1430917381535_2
DESCRIPTION
NAMESPACE
+test_support/*
diff --git a/LICENSE b/LICENSE
index 9d1b00beff748..d0cd0dcb4bdb7 100644
--- a/LICENSE
+++ b/LICENSE
@@ -853,6 +853,52 @@ and
Vis.js may be distributed under either license.
+========================================================================
+For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js):
+========================================================================
+Copyright (c) 2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+========================================================================
+For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js):
+========================================================================
+Copyright (c) 2012-2013 Chris Pettitt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/R/create-docs.sh b/R/create-docs.sh
index 4194172a2e115..6a4687b06ecb9 100755
--- a/R/create-docs.sh
+++ b/R/create-docs.sh
@@ -23,14 +23,14 @@
# After running this script the html docs can be found in
# $SPARK_HOME/R/pkg/html
+set -o pipefail
+set -e
+
# Figure out where the script is
export FWDIR="$(cd "`dirname "$0"`"; pwd)"
pushd $FWDIR
-# Generate Rd file
-Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))'
-
-# Install the package
+# Install the package (this will also generate the Rd files)
./install-dev.sh
# Now create HTML files
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 55ed6f4be1a4a..1edd551f8d243 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -26,11 +26,20 @@
# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory
# to load the SparkR package on the worker nodes.
+set -o pipefail
+set -e
FWDIR="$(cd `dirname $0`; pwd)"
LIB_DIR="$FWDIR/lib"
mkdir -p $LIB_DIR
-# Install R
+pushd $FWDIR
+
+# Generate Rd files if devtools is installed
+Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
+
+# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
+
+popd
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 411126a377950..f9447f6c3288d 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -19,9 +19,11 @@ exportMethods("arrange",
"count",
"describe",
"distinct",
+ "dropna",
"dtypes",
"except",
"explain",
+ "fillna",
"filter",
"first",
"group_by",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index ed8093c80d360..0af5cb8881e35 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1314,9 +1314,8 @@ setMethod("except",
#' write.df(df, "myfile", "parquet", "overwrite")
#' }
setMethod("write.df",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
if (is.null(source)) {
sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
@@ -1338,9 +1337,8 @@ setMethod("write.df",
#' @aliases saveDF
#' @export
setMethod("saveDF",
- signature(df = "DataFrame", path = 'character', source = 'character',
- mode = 'character'),
- function(df, path = NULL, source = NULL, mode = "append", ...){
+ signature(df = "DataFrame", path = 'character'),
+ function(df, path, source = NULL, mode = "append", ...){
write.df(df, path, source, mode, ...)
})
@@ -1431,3 +1429,128 @@ setMethod("describe",
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
dataFrame(sdf)
})
+
+#' dropna
+#'
+#' Returns a new DataFrame omitting rows with null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param how "any" or "all".
+#' if "any", drop a row if it contains any nulls.
+#' if "all", drop a row only if all its values are null.
+#' if minNonNulls is specified, how is ignored.
+#' @param minNonNulls If specified, drop rows that have less than
+#' minNonNulls non-null values.
+#' This overrides the how parameter.
+#' @param cols Optional list of column names to consider.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' dropna(df)
+#' }
+setMethod("dropna",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ how <- match.arg(how)
+ if (is.null(cols)) {
+ cols <- columns(x)
+ }
+ if (is.null(minNonNulls)) {
+ minNonNulls <- if (how == "any") { length(cols) } else { 1 }
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- callJMethod(naFunctions, "drop",
+ as.integer(minNonNulls), listToSeq(as.list(cols)))
+ dataFrame(sdf)
+ })
+
+#' @aliases dropna
+#' @export
+setMethod("na.omit",
+ signature(x = "DataFrame"),
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ dropna(x, how, minNonNulls, cols)
+ })
+
+#' fillna
+#'
+#' Replace null values.
+#'
+#' @param x A SparkSQL DataFrame.
+#' @param value Value to replace null values with.
+#' Should be an integer, numeric, character or named list.
+#' If the value is a named list, then cols is ignored and
+#' value must be a mapping from column name (character) to
+#' replacement value. The replacement value must be an
+#' integer, numeric or character.
+#' @param cols optional list of column names to consider.
+#' Columns specified in cols that do not have matching data
+#' type are ignored. For example, if value is a character, and
+#' subset contains a non-character column, then the non-character
+#' column is simply ignored.
+#' @return A DataFrame
+#'
+#' @rdname nafunctions
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' fillna(df, 1)
+#' fillna(df, list("age" = 20, "name" = "unknown"))
+#' }
+setMethod("fillna",
+ signature(x = "DataFrame"),
+ function(x, value, cols = NULL) {
+ if (!(class(value) %in% c("integer", "numeric", "character", "list"))) {
+ stop("value should be an integer, numeric, charactor or named list.")
+ }
+
+ if (class(value) == "list") {
+ # Check column names in the named list
+ colNames <- names(value)
+ if (length(colNames) == 0 || !all(colNames != "")) {
+ stop("value should be an a named list with each name being a column name.")
+ }
+
+ # Convert the named list to an environment to be passed to the JVM
+ valueMap <- new.env()
+ for (col in colNames) {
+ # Check each item in the named list is of valid type
+ v <- value[[col]]
+ if (!(class(v) %in% c("integer", "numeric", "character"))) {
+ stop("Each item in value should be an integer, numeric or charactor.")
+ }
+ valueMap[[col]] <- v
+ }
+
+ # When value is a named list, caller is expected not to pass in cols
+ if (!is.null(cols)) {
+ warning("When value is a named list, cols is ignored!")
+ cols <- NULL
+ }
+
+ value <- valueMap
+ } else if (is.integer(value)) {
+ # Cast an integer to a numeric
+ value <- as.numeric(value)
+ }
+
+ naFunctions <- callJMethod(x@sdf, "na")
+ sdf <- if (length(cols) == 0) {
+ callJMethod(naFunctions, "fill", value)
+ } else {
+ callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
+ }
+ dataFrame(sdf)
+ })
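
The dropna() and fillna() wrappers above delegate to the JVM DataFrame's na functions through callJMethod. As a rough, hedged Scala sketch of the calls they resolve to on the JVM side (the JSON path, column names, and the helper function are placeholders for illustration, not part of this patch):

    import org.apache.spark.sql.SQLContext

    // Rough JVM-side equivalents of the SparkR calls; paths and column
    // names are placeholders.
    def naExamples(sqlContext: SQLContext): Unit = {
      val df = sqlContext.read.json("path/to/file.json")

      // dropna(df): drop rows that contain any null value
      val noNulls = df.na.drop("any")

      // dropna(df, minNonNulls = 2, cols = c("age", "height"))
      val atLeastTwo = df.na.drop(2, Seq("age", "height"))

      // fillna(df, 50.6): fill numeric nulls with a constant
      val filledNum = df.na.fill(50.6)

      // fillna(df, list("age" = 50, "name" = "unknown"))
      val filledByColumn = df.na.fill(Map("age" -> 50, "name" -> "unknown"))
    }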
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 36cc612875879..22a4b5bf86ebd 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -452,20 +452,31 @@ dropTempTable <- function(sqlContext, tableName) {
#' df <- read.df(sqlContext, "path/to/file.json", source = "json")
#' }
-read.df <- function(sqlContext, path = NULL, source = NULL, ...) {
+read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
options <- varargsToEnv(...)
if (!is.null(path)) {
options[['path']] <- path
}
- sdf <- callJMethod(sqlContext, "load", source, options)
+ if (is.null(source)) {
+ sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv)
+ source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ }
+ if (!is.null(schema)) {
+ stopifnot(class(schema) == "structType")
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
+ schema$jobj, options)
+ } else {
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
+ }
dataFrame(sdf)
}
#' @aliases loadDF
#' @export
-loadDF <- function(sqlContext, path = NULL, source = NULL, ...) {
- read.df(sqlContext, path, source, ...)
+loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
+ read.df(sqlContext, path, source, schema, ...)
}
#' Create an external table
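
The schema-aware read.df()/loadDF() above ends up calling SQLUtils.loadDF on the JVM. A hedged Scala sketch of roughly equivalent DataFrameReader usage (the path, field names, and helper function are illustrative only):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

    // Illustrative only: mirrors read.df(sqlContext, "path/to/file.json", "json", schema)
    def readJsonWithSchema(sqlContext: SQLContext) = {
      val schema = StructType(Seq(
        StructField("name", StringType),
        StructField("age", DoubleType)))

      // With no source given, the R wrapper falls back to the configured default:
      // sqlContext.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet")
      sqlContext.read
        .format("json")
        .schema(schema)   // omit this call to let the data source infer the schema
        .load("path/to/file.json")
    }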
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a23d3b217b2fd..12e09176c9f92 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") })
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
+#' @rdname nafunctions
+#' @export
+setGeneric("dropna",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("dropna")
+ })
+
+#' @rdname nafunctions
+#' @export
+setGeneric("na.omit",
+ function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) {
+ standardGeneric("na.omit")
+ })
+
#' @rdname schema
#' @export
setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
@@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") })
#' @export
setGeneric("except", function(x, y) { standardGeneric("except") })
+#' @rdname nafunctions
+#' @export
+setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") })
+
#' @rdname filter
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
@@ -482,11 +500,11 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, source, mode, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
#' @rdname write.df
#' @export
-setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") })
+setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") })
#' @rdname schema
#' @export
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index c53d0a961016f..3169d7968f8fe 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -37,6 +37,14 @@ writeObject <- function(con, object, writeType = TRUE) {
# passing in vectors as arrays and instead require arrays to be passed
# as lists.
type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt")
+ # Checking types is needed here, since ‘is.na’ only handles atomic vectors,
+ # lists and pairlists
+ if (type %in% c("integer", "character", "logical", "double", "numeric")) {
+ if (is.na(object)) {
+ object <- NULL
+ type <- "NULL"
+ }
+ }
if (writeType) {
writeType(con, type)
}
@@ -160,6 +168,14 @@ writeList <- function(con, arr) {
}
}
+# Used to pass arrays where the elements can be of different types
+writeGenericList <- function(con, list) {
+ writeInt(con, length(list))
+ for (elem in list) {
+ writeObject(con, elem)
+ }
+}
+
# Used to pass in hash maps required on Java side.
writeEnv <- function(con, env) {
len <- length(env)
@@ -168,7 +184,7 @@ writeEnv <- function(con, env) {
if (len > 0) {
writeList(con, as.list(ls(env)))
vals <- lapply(ls(env), function(x) { env[[x]] })
- writeList(con, as.list(vals))
+ writeGenericList(con, as.list(vals))
}
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 68387f0f5365d..5ced7c688f98a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -225,14 +225,21 @@ sparkR.init <- function(
#' sqlContext <- sparkRSQL.init(sc)
#'}
-sparkRSQL.init <- function(jsc) {
+sparkRSQL.init <- function(jsc = NULL) {
if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
return(get(".sparkRSQLsc", envir = .sparkREnv))
}
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
- "createSQLContext",
- jsc)
+ "createSQLContext",
+ sc)
assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv)
sqlContext
}
@@ -249,12 +256,19 @@ sparkRSQL.init <- function(jsc) {
#' sqlContext <- sparkRHive.init(sc)
#'}
-sparkRHive.init <- function(jsc) {
+sparkRHive.init <- function(jsc = NULL) {
if (exists(".sparkRHivesc", envir = .sparkREnv)) {
return(get(".sparkRHivesc", envir = .sparkREnv))
}
- ssc <- callJMethod(jsc, "sc")
+ # If jsc is NULL, create a Spark Context
+ sc <- if (is.null(jsc)) {
+ sparkR.init()
+ } else {
+ jsc
+ }
+
+ ssc <- callJMethod(sc, "sc")
hiveCtx <- tryCatch({
newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
}, error = function(err) {
diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R
index ca94f1d4e7fd5..773b6ecf582d9 100644
--- a/R/pkg/inst/profile/shell.R
+++ b/R/pkg/inst/profile/shell.R
@@ -24,7 +24,7 @@
old <- getOption("defaultPackages")
options(defaultPackages = c(old, "SparkR"))
- sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = ""))
+ sc <- SparkR::sparkR.init()
assign("sc", sc, envir=.GlobalEnv)
sqlContext <- SparkR::sparkRSQL.init(sc)
assign("sqlContext", sqlContext, envir=.GlobalEnv)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1857e636e8577..8946348ef801c 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet")
writeLines(mockLines, jsonPath)
+# For testing NA functions, such as dropna() and fillna()
+mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}",
+ "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}",
+ "{\"name\":\"David\",\"age\":60,\"height\":null}",
+ "{\"name\":\"Amy\",\"age\":null,\"height\":null}",
+ "{\"name\":null,\"age\":null,\"height\":null}")
+jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp")
+writeLines(mockLinesNa, jsonPathNa)
+
test_that("infer types", {
expect_equal(infer_type(1L), "integer")
expect_equal(infer_type(1.0), "double")
@@ -92,6 +101,43 @@ test_that("create DataFrame from RDD", {
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
})
+test_that("convert NAs to null type in DataFrames", {
+ rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L)))
+ df <- createDataFrame(sqlContext, rdd, list("a", "b"))
+ expect_true(is.na(collect(df)[2, "a"]))
+ expect_equal(collect(df)[2, "b"], 4L)
+
+ l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L))
+ df <- createDataFrame(sqlContext, l)
+ expect_equal(collect(df)[2, "x"], 1L)
+ expect_true(is.na(collect(df)[2, "y"]))
+
+ rdd <- parallelize(sc, list(list(1, 2), list(NA, 4)))
+ df <- createDataFrame(sqlContext, rdd, list("a", "b"))
+ expect_true(is.na(collect(df)[2, "a"]))
+ expect_equal(collect(df)[2, "b"], 4)
+
+ l <- data.frame(x = 1, y = c(1, NA_real_, 3))
+ df <- createDataFrame(sqlContext, l)
+ expect_equal(collect(df)[2, "x"], 1)
+ expect_true(is.na(collect(df)[2, "y"]))
+
+ l <- list("a", "b", NA, "d")
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], "d")
+
+ l <- list("a", "b", NA_character_, "d")
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], "d")
+
+ l <- list(TRUE, FALSE, NA, TRUE)
+ df <- createDataFrame(sqlContext, l)
+ expect_true(is.na(collect(df)[3, "_1"]))
+ expect_equal(collect(df)[4, "_1"], TRUE)
+})
+
test_that("toDF", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- toDF(rdd, list("a", "b"))
@@ -495,6 +541,19 @@ test_that("read.df() from json file", {
df <- read.df(sqlContext, jsonPath, "json")
expect_true(inherits(df, "DataFrame"))
expect_true(count(df) == 3)
+
+ # Check if we can apply a user defined schema
+ schema <- structType(structField("name", type = "string"),
+ structField("age", type = "double"))
+
+ df1 <- read.df(sqlContext, jsonPath, "json", schema)
+ expect_true(inherits(df1, "DataFrame"))
+ expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
+
+ # Run the same with loadDF
+ df2 <- loadDF(sqlContext, jsonPath, "json", schema)
+ expect_true(inherits(df2, "DataFrame"))
+ expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
})
test_that("write.df() as parquet file", {
@@ -765,5 +824,105 @@ test_that("describe() on a DataFrame", {
expect_equal(collect(stats)[5, "age"], "30")
})
+test_that("dropna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # drop with columns
+
+ expected <- rows[!is.na(rows$name),]
+ actual <- collect(dropna(df, cols = "name"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age),]
+ actual <- collect(dropna(df, cols = "age"))
+ row.names(expected) <- row.names(actual)
+ # identical on two dataframes does not work here. Don't know why.
+ # use identical on all columns as a workaround.
+ expect_true(identical(expected$age, actual$age))
+ expect_true(identical(expected$height, actual$height))
+ expect_true(identical(expected$name, actual$name))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ # drop with how
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),]
+ actual <- collect(dropna(df, "all"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),]
+ actual <- collect(dropna(df, "any"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) & !is.na(rows$height),]
+ actual <- collect(dropna(df, "any", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[!is.na(rows$age) | !is.na(rows$height),]
+ actual <- collect(dropna(df, "all", cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ # drop with threshold
+
+ expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,]
+ actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height")))
+ expect_true(identical(expected, actual))
+
+ expected <- rows[as.integer(!is.na(rows$age)) +
+ as.integer(!is.na(rows$height)) +
+ as.integer(!is.na(rows$name)) >= 3,]
+ actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height")))
+ expect_true(identical(expected, actual))
+})
+
+test_that("fillna() on a DataFrame", {
+ df <- jsonFile(sqlContext, jsonPathNa)
+ rows <- collect(df)
+
+ # fill with value
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ actual <- collect(fillna(df, 50.6))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ actual <- collect(fillna(df, 50.6, "age"))
+ expect_true(identical(expected, actual))
+
+ expected <- rows
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, "unknown", c("age", "name")))
+ expect_true(identical(expected, actual))
+
+ # fill with named list
+
+ expected <- rows
+ expected$age[is.na(expected$age)] <- 50
+ expected$height[is.na(expected$height)] <- 50.6
+ expected$name[is.na(expected$name)] <- "unknown"
+ actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown")))
+ expect_true(identical(expected, actual))
+})
+
unlink(parquetPath)
unlink(jsonPath)
+unlink(jsonPathNa)
diff --git a/README.md b/README.md
index 9c09d40e2bdae..380422ca00dbe 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
Spark is a fast and general cluster computing system for Big Data. It provides
high-level APIs in Scala, Java, and Python, and an optimized engine that
supports general computation graphs for data analysis. It also supports a
-rich set of higher-level tools including Spark SQL for SQL and structured
-data processing, MLlib for machine learning, GraphX for graph processing,
+rich set of higher-level tools including Spark SQL for SQL and DataFrames,
+MLlib for machine learning, GraphX for graph processing,
and Spark Streaming for stream processing.
@@ -22,7 +22,7 @@ This README file only contains basic setup instructions.
Spark is built using [Apache Maven](http://maven.apache.org/).
To build Spark and its example programs, run:
- mvn -DskipTests clean package
+ build/mvn -DskipTests clean package
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
@@ -43,7 +43,7 @@ Try the following command, which should return 1000:
Alternatively, if you prefer Python, you can use the Python shell:
./bin/pyspark
-
+
And run the following command, which should also return 1000:
>>> sc.parallelize(range(1000)).count()
@@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example:
will run the Pi example locally.
You can set the MASTER environment variable when running examples to submit
-examples to a cluster. This can be a mesos:// or spark:// URL,
-"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
-locally with one thread, or "local[N]" to run locally with N threads. You
+examples to a cluster. This can be a mesos:// or spark:// URL,
+"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run
+locally with one thread, or "local[N]" to run locally with N threads. You
can also use an abbreviated class name if the class is in the `examples`
package. For instance:
@@ -75,7 +75,7 @@ can be run using:
./dev/run-tests
-Please see the guidance on how to
+Please see the guidance on how to
[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools).
## A Note About Hadoop Versions
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 626c8577e31fe..e9c6d26ccddc7 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 1f3dec91314f2..ed5c37e595a96 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -40,6 +40,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index ccb262a4ee02a..fb10d734ac74b 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.bagel
-import org.scalatest.{BeforeAndAfter, FunSuite, Assertions}
+import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
-class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
+class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
diff --git a/bin/pyspark b/bin/pyspark
index 8acad6113797d..f9dbddfa53560 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -17,24 +17,10 @@
# limitations under the License.
#
-# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
source "$SPARK_HOME"/bin/load-spark-env.sh
-
-function usage() {
- if [ -n "$1" ]; then
- echo $1
- fi
- echo "Usage: ./bin/pyspark [options]" 1>&2
- "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit $2
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage
-fi
+export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
# executable, while the worker would still be launched using PYSPARK_PYTHON.
@@ -90,11 +76,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
- if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
- else
- exec "$PYSPARK_DRIVER_PYTHON" $1
- fi
+ exec "$PYSPARK_DRIVER_PYTHON" -m $1
exit
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 09b4149c2a439..45e9e3def5121 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -21,6 +21,7 @@ rem Figure out where the Spark framework is installed
set SPARK_HOME=%~dp0..
call %SPARK_HOME%\bin\load-spark-env.cmd
+set _SPARK_CMD_USAGE=Usage: bin\pyspark.cmd [options]
rem Figure out which Python to use.
if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
diff --git a/bin/spark-class b/bin/spark-class
index c49d97ce5cf25..2b59e5df5736f 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -16,18 +16,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-set -e
# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
. "$SPARK_HOME"/bin/load-spark-env.sh
-if [ -z "$1" ]; then
- echo "Usage: spark-class []" 1>&2
- exit 1
-fi
-
# Find the java binary
if [ -n "${JAVA_HOME}" ]; then
RUNNER="${JAVA_HOME}/bin/java"
@@ -64,24 +58,6 @@ fi
SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}"
-# Verify that versions of java used to build the jars and run Spark are compatible
-if [ -n "$JAVA_HOME" ]; then
- JAR_CMD="$JAVA_HOME/bin/jar"
-else
- JAR_CMD="jar"
-fi
-
-if [ $(command -v "$JAR_CMD") ] ; then
- jar_error_check=$("$JAR_CMD" -tf "$SPARK_ASSEMBLY_JAR" nonexistent/class/path 2>&1)
- if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
- echo "Loading Spark jar with '$JAR_CMD' failed. " 1>&2
- echo "This is likely because Spark was compiled with Java 7 and run " 1>&2
- echo "with Java 6. (see SPARK-1703). Please use Java 7 to run Spark " 1>&2
- echo "or build Spark with Java 6." 1>&2
- exit 1
- fi
-fi
-
LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR"
# Add the launcher build dir to the classpath if requested.
@@ -98,9 +74,4 @@ CMD=()
while IFS= read -d '' -r ARG; do
CMD+=("$ARG")
done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@")
-
-if [ "${CMD[0]}" = "usage" ]; then
- "${CMD[@]}"
-else
- exec "${CMD[@]}"
-fi
+exec "${CMD[@]}"
diff --git a/bin/spark-shell b/bin/spark-shell
index b3761b5e1375b..a6dc863d83fc6 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -29,20 +29,7 @@ esac
set -o posix
export FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
-
-usage() {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- echo "Usage: ./bin/spark-shell [options]"
- "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit "$2"
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage "" 0
-fi
+export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]"
# SPARK-4161: scala does not assume use of the java classpath,
# so we need to add the "-Dscala.usejavacp=true" flag manually. We
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 00fd30fa38d36..251309d67f860 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -18,12 +18,7 @@ rem limitations under the License.
rem
set SPARK_HOME=%~dp0..
-
-echo "%*" | findstr " \<--help\> \<-h\>" >nul
-if %ERRORLEVEL% equ 0 (
- call :usage
- exit /b 0
-)
+set _SPARK_CMD_USAGE=Usage: .\bin\spark-shell.cmd [options]
rem SPARK-4161: scala does not assume use of the java classpath,
rem so we need to add the "-Dscala.usejavacp=true" flag manually. We
@@ -37,16 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" (
set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true"
:run_shell
-call %SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
-set SPARK_ERROR_LEVEL=%ERRORLEVEL%
-if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" (
- call :usage
- exit /b 1
-)
-exit /b %SPARK_ERROR_LEVEL%
-
-:usage
-echo %SPARK_LAUNCHER_USAGE_ERROR%
-echo "Usage: .\bin\spark-shell.cmd [options]" >&2
-call %SPARK_HOME%\bin\spark-submit2.cmd --help 2>&1 | findstr /V "Usage" 1>&2
-goto :eof
+%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %*
diff --git a/bin/spark-sql b/bin/spark-sql
index ca1729f4cfcb4..4ea7bc6e39c07 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -17,41 +17,6 @@
# limitations under the License.
#
-#
-# Shell script for starting the Spark SQL CLI
-
-# Enter posix mode for bash
-set -o posix
-
-# NOTE: This exact class name is matched downstream by SparkSubmit.
-# Any changes need to be reflected there.
-export CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
-
-# Figure out where Spark is installed
export FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
-
-function usage {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- echo "Usage: ./bin/spark-sql [options] [cli option]"
- pattern="usage"
- pattern+="\|Spark assembly has been built with Hive"
- pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set"
- pattern+="\|Spark Command: "
- pattern+="\|--help"
- pattern+="\|======="
-
- "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- echo
- echo "CLI options:"
- "$FWDIR"/bin/spark-class "$CLASS" --help 2>&1 | grep -v "$pattern" 1>&2
- exit "$2"
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage "" 0
-fi
-
-exec "$FWDIR"/bin/spark-submit --class "$CLASS" "$@"
+export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]"
+exec "$FWDIR"/bin/spark-submit --class org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver "$@"
diff --git a/bin/spark-submit b/bin/spark-submit
index 0e0afe71a0f05..255378b0f077c 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -22,16 +22,4 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
# disable randomized hash for string in Python 3.3+
export PYTHONHASHSEED=0
-# Only define a usage function if an upstream script hasn't done so.
-if ! type -t usage >/dev/null 2>&1; then
- usage() {
- if [ -n "$1" ]; then
- echo "$1"
- fi
- "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit --help
- exit "$2"
- }
- export -f usage
-fi
-
exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index d3fc4a5cc3f6e..651376e526928 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,15 +24,4 @@ rem disable randomized hash for string in Python 3.3+
set PYTHONHASHSEED=0
set CLASS=org.apache.spark.deploy.SparkSubmit
-call %~dp0spark-class2.cmd %CLASS% %*
-set SPARK_ERROR_LEVEL=%ERRORLEVEL%
-if not "x%SPARK_LAUNCHER_USAGE_ERROR%"=="x" (
- call :usage
- exit /b 1
-)
-exit /b %SPARK_ERROR_LEVEL%
-
-:usage
-echo %SPARK_LAUNCHER_USAGE_ERROR%
-call %SPARK_HOME%\bin\spark-class2.cmd %CLASS% --help
-goto :eof
+%~dp0spark-class2.cmd %CLASS% %*
diff --git a/bin/sparkR b/bin/sparkR
index 8c918e2b09aef..464c29f369424 100755
--- a/bin/sparkR
+++ b/bin/sparkR
@@ -17,23 +17,7 @@
# limitations under the License.
#
-# Figure out where Spark is installed
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
-
source "$SPARK_HOME"/bin/load-spark-env.sh
-
-function usage() {
- if [ -n "$1" ]; then
- echo $1
- fi
- echo "Usage: ./bin/sparkR [options]" 1>&2
- "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
- exit $2
-}
-export -f usage
-
-if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
- usage
-fi
-
+export _SPARK_CMD_USAGE="Usage: ./bin/sparkR [options]"
exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@"
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 7de0011a48ca8..7f17bc7eea4f5 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -4,7 +4,7 @@
# divided into instances which correspond to internal components.
# Each instance can be configured to report its metrics to one or more sinks.
# Accepted values for [instance] are "master", "worker", "executor", "driver",
-# and "applications". A wild card "*" can be used as an instance name, in
+# and "applications". A wildcard "*" can be used as an instance name, in
# which case all instances will inherit the supplied property.
#
# Within an instance, a "source" specifies a particular set of grouped metrics.
@@ -32,7 +32,7 @@
# name (see examples below).
# 2. Some sinks involve a polling period. The minimum allowed polling period
# is 1 second.
-# 3. Wild card properties can be overridden by more specific properties.
+# 3. Wildcard properties can be overridden by more specific properties.
# For example, master.sink.console.period takes precedence over
# *.sink.console.period.
# 4. A metrics specific configuration
@@ -47,6 +47,13 @@
# instance master and applications. MetricsServlet may not be configured by self.
#
+## List of available common sources and their properties.
+
+# org.apache.spark.metrics.source.JvmSource
+# Note: Currently, JvmSource is the only available common source
+# that can be added to an instance. To enable it, set the "class"
+# option to its fully qualified class name (see examples below).
+
## List of available sinks and their properties.
# org.apache.spark.metrics.sink.ConsoleSink
diff --git a/core/pom.xml b/core/pom.xml
index bfa49d0d6dc25..40a64beccdc24 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -338,6 +338,12 @@
     <dependency>
       <groupId>org.seleniumhq.selenium</groupId>
       <artifactId>selenium-java</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
       <scope>test</scope>
     </dependency>
@@ -377,9 +383,15 @@
       <scope>test</scope>
     </dependency>
     <dependency>
-      <groupId>org.spark-project</groupId>
+      <groupId>net.razorvine</groupId>
       <artifactId>pyrolite</artifactId>
       <version>4.4</version>
+      <exclusions>
+        <exclusion>
+          <groupId>net.razorvine</groupId>
+          <artifactId>serpent</artifactId>
+        </exclusion>
+      </exclusions>
     </dependency>
     <dependency>
       <groupId>net.sf.py4j</groupId>
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000000000..d3d6280284beb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ *
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ *
+ * <ul>
+ *    <li>no Ordering is specified,</li>
+ *    <li>no Aggregator is specified, and</li>
+ *    <li>the number of partitions is less than
+ *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ *
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+ private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+ private final int fileBufferSize;
+ private final boolean transferToEnabled;
+ private final int numPartitions;
+ private final BlockManager blockManager;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final Serializer serializer;
+
+ /** Array of file writers, one for each partition */
+ private BlockObjectWriter[] partitionWriters;
+
+ public BypassMergeSortShuffleWriter(
+ SparkConf conf,
+ BlockManager blockManager,
+ Partitioner partitioner,
+ ShuffleWriteMetrics writeMetrics,
+ Serializer serializer) {
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+ this.numPartitions = partitioner.numPartitions();
+ this.blockManager = blockManager;
+ this.partitioner = partitioner;
+ this.writeMetrics = writeMetrics;
+ this.serializer = serializer;
+ }
+
+ @Override
+ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+ assert (partitionWriters == null);
+ if (!records.hasNext()) {
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new BlockObjectWriter[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+ }
+
+ for (BlockObjectWriter writer : partitionWriters) {
+ writer.commitAndClose();
+ }
+ }
+
+ @Override
+ public long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException {
+ // Track location of the partition starts in the output file
+ final long[] lengths = new long[numPartitions];
+ if (partitionWriters == null) {
+ // We were passed an empty iterator
+ return lengths;
+ }
+
+ final FileOutputStream out = new FileOutputStream(outputFile, true);
+ final long writeStartTime = System.nanoTime();
+ boolean threwException = true;
+ try {
+ for (int i = 0; i < numPartitions; i++) {
+ final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+ boolean copyThrewException = true;
+ try {
+ lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
+ }
+ threwException = false;
+ } finally {
+ Closeables.close(out, threwException);
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+ }
+ partitionWriters = null;
+ return lengths;
+ }
+
+ @Override
+ public void stop() throws IOException {
+ if (partitionWriters != null) {
+ try {
+ final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+ for (BlockObjectWriter writer : partitionWriters) {
+ // This method explicitly does _not_ throw exceptions:
+ writer.revertPartialWritesAndClose();
+ if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+ logger.error("Error while deleting file for block {}", writer.blockId());
+ }
+ }
+ } finally {
+ partitionWriters = null;
+ }
+ }
+ }
+}
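
As a hedged illustration of the selection rule described in the class comment above, the sketch below mirrors the three conditions; the helper name, default threshold of 200, and exact comparison are illustrative, not Spark's internal code:

    import org.apache.spark.{Aggregator, Partitioner, SparkConf}

    // Illustrative helper, not Spark's internal method: the bypass writer is
    // only chosen when there is no map-side aggregation, no key ordering, and
    // few enough reduce partitions.
    def usesBypassMergeSort(
        conf: SparkConf,
        partitioner: Partitioner,
        aggregator: Option[Aggregator[_, _, _]],
        keyOrdering: Option[Ordering[_]]): Boolean = {
      val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      aggregator.isEmpty && keyOrdering.isEmpty && partitioner.numPartitions <= threshold
    }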
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000000000..656ea0401a144
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+ void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+ /**
+ * Write all the data added into this shuffle sorter into a file in the disk store. This is
+ * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+ * binary files if we decided to avoid merge-sorting.
+ *
+ * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+ * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+ */
+ long[] writePartitionedFile(
+ BlockId blockId,
+ TaskContext context,
+ File outputFile) throws IOException;
+
+ void stop() throws IOException;
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index 013db8df9b363..0b450dc76bc38 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -50,4 +50,9 @@ $(function() {
$("span.additional-metric-title").click(function() {
$(this).parent().find('input[type="checkbox"]').trigger('click');
});
+
+ // Trigger a double click on the span to show full job description.
+ $(".description-input").dblclick(function() {
+ $(this).removeClass("description-input").addClass("description-input-full");
+ });
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index dbacbf19beee5..dde6069000bc4 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -100,7 +100,7 @@ sorttable = {
this.removeChild(document.getElementById('sorttable_sortfwdind'));
sortrevind = document.createElement('span');
sortrevind.id = "sorttable_sortrevind";
- sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴';
+ sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾';
this.appendChild(sortrevind);
return;
}
@@ -113,7 +113,7 @@ sorttable = {
this.removeChild(document.getElementById('sorttable_sortrevind'));
sortfwdind = document.createElement('span');
sortfwdind.id = "sorttable_sortfwdind";
- sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾';
+ sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
this.appendChild(sortfwdind);
return;
}
@@ -134,7 +134,7 @@ sorttable = {
this.className += ' sorttable_sorted';
sortfwdind = document.createElement('span');
sortfwdind.id = "sorttable_sortfwdind";
- sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾';
+ sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴';
this.appendChild(sortfwdind);
// build an array to sort. This is a Schwartzian transform thing,
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index aaeba5b1027c9..7a0dec2a3eaec 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -193,7 +193,7 @@ function renderDagVizForJob(svgContainer) {
// Use the link from the stage table so it also works for the history server
var attemptId = 0
var stageLink = d3.select("#stage-" + stageId + "-" + attemptId)
- .select("a")
+ .select("a.name-link")
.attr("href") + "&expandDagViz=true";
container = svgContainer
.append("a")
@@ -235,7 +235,7 @@ function renderDagVizForJob(svgContainer) {
// them separately later. Note that we cannot draw them now because we need to
// put these edges in a separate container that is on top of all stage graphs.
metadata.selectAll(".incoming-edge").each(function(v) {
- var edge = d3.select(this).text().split(","); // e.g. 3,4 => [3, 4]
+ var edge = d3.select(this).text().trim().split(","); // e.g. 3,4 => [3, 4]
crossStageEdges.push(edge);
});
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
index 604c29994145a..ca74ef9d7e94e 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
@@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) {
};
$(this).click(function() {
- var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href")
+ var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href")
window.location.href = jobPagePath
});
@@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) {
};
$(this).click(function() {
- var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href")
+ var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href")
window.location.href = stagePagePath
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index e7c1d475d4e52..b1cef47042247 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -135,6 +135,14 @@ pre {
display: block;
}
+.description-input-full {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: normal;
+ display: block;
+}
+
.stacktrace-details {
max-height: 300px;
overflow-y: auto;
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 330df1d59a9b1..5a8d17bd99933 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @tparam T result type
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
- extends Accumulable[T,T](initialValue, param, name) {
+ extends Accumulable[T, T](initialValue, param, name) {
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
}
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index af9765d313e9e..ceeb58075d345 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,8 +34,8 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- // When spilling is enabled sorting will happen externally, but not necessarily with an
- // ExternalSorter.
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
@@ -45,7 +45,7 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
if (!isSpillEnabled) {
- val combiners = new AppendOnlyMap[K,C]
+ val combiners = new AppendOnlyMap[K, C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
@@ -76,7 +76,7 @@ case class Aggregator[K, V, C] (
: Iterator[(K, C)] =
{
if (!isSpillEnabled) {
- val combiners = new AppendOnlyMap[K,C]
+ val combiners = new AppendOnlyMap[K, C]
var kc: Product2[K, C] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 9514604752640..49329423dca76 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -101,6 +101,9 @@ private[spark] class ExecutorAllocationManager(
private val executorIdleTimeoutS = conf.getTimeAsSeconds(
"spark.dynamicAllocation.executorIdleTimeout", "60s")
+ private val cachedExecutorIdleTimeoutS = conf.getTimeAsSeconds(
+ "spark.dynamicAllocation.cachedExecutorIdleTimeout", s"${2 * executorIdleTimeoutS}s")
+
// During testing, the methods to actually kill and add executors are mocked out
private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
@@ -150,6 +153,13 @@ private[spark] class ExecutorAllocationManager(
// Metric source for ExecutorAllocationManager to expose internal status to MetricsSystem.
val executorAllocationManagerSource = new ExecutorAllocationManagerSource
+ // Whether we are still waiting for the initial set of executors to be allocated.
+ // While this is true, we will not cancel outstanding executor requests. This is
+ // set to false when:
+ // (1) a stage is submitted, or
+ // (2) an executor idle timeout has elapsed.
+ @volatile private var initializing: Boolean = true
+
/**
* Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
@@ -240,6 +250,7 @@ private[spark] class ExecutorAllocationManager(
removeTimes.retain { case (executorId, expireTime) =>
val expired = now >= expireTime
if (expired) {
+ initializing = false
removeExecutor(executorId)
}
!expired
@@ -261,15 +272,23 @@ private[spark] class ExecutorAllocationManager(
private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized {
val maxNeeded = maxNumExecutorsNeeded
- if (maxNeeded < numExecutorsTarget) {
+ if (initializing) {
+ // Do not change our target while we are still initializing.
+ // Otherwise the first job may have to ramp up unnecessarily.
+ 0
+ } else if (maxNeeded < numExecutorsTarget) {
// The target number exceeds the number we actually need, so stop adding new
// executors and inform the cluster manager to cancel the extra pending requests
val oldNumExecutorsTarget = numExecutorsTarget
numExecutorsTarget = math.max(maxNeeded, minNumExecutors)
- client.requestTotalExecutors(numExecutorsTarget)
numExecutorsToAdd = 1
- logInfo(s"Lowering target number of executors to $numExecutorsTarget because " +
- s"not all requests are actually needed (previously $oldNumExecutorsTarget)")
+
+ // If the new target has not changed, avoid sending a message to the cluster manager
+ if (numExecutorsTarget < oldNumExecutorsTarget) {
+ client.requestTotalExecutors(numExecutorsTarget)
+ logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " +
+ s"$oldNumExecutorsTarget) because not all requested executors are actually needed")
+ }
numExecutorsTarget - oldNumExecutorsTarget
} else if (addTime != NOT_SET && now >= addTime) {
val delta = addExecutors(maxNeeded)
@@ -443,9 +462,23 @@ private[spark] class ExecutorAllocationManager(
private def onExecutorIdle(executorId: String): Unit = synchronized {
if (executorIds.contains(executorId)) {
if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ // Note that it is not necessary to query the executors since all the cached
+ // blocks we are concerned with are reported to the driver. Note that this
+ // does not include broadcast blocks.
+ val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId)
+ val now = clock.getTimeMillis()
+ val timeout = {
+ if (hasCachedBlocks) {
+ // Use a different timeout if the executor has cached blocks.
+ now + cachedExecutorIdleTimeoutS * 1000
+ } else {
+ now + executorIdleTimeoutS * 1000
+ }
+ }
+ val realTimeout = if (timeout <= 0) Long.MaxValue else timeout // overflow
+ removeTimes(executorId) = realTimeout
logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $executorIdleTimeoutS seconds)")
- removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeoutS * 1000
+ s"scheduled to run on the executor (to expire in ${(realTimeout - now)/1000} seconds)")
}
} else {
logWarning(s"Attempted to mark unknown executor $executorId idle")
@@ -477,6 +510,7 @@ private[spark] class ExecutorAllocationManager(
private var numRunningTasks: Int = _
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ initializing = false
val stageId = stageSubmitted.stageInfo.stageId
val numTasks = stageSubmitted.stageInfo.numTasks
allocationManager.synchronized {
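A minimal sketch of the timeout selection introduced above, assuming only the two configuration keys shown in this patch (`spark.dynamicAllocation.executorIdleTimeout` and `spark.dynamicAllocation.cachedExecutorIdleTimeout`, the latter defaulting to twice the former); the method name is illustrative, not part of the patch:

// Executors holding cached blocks get the longer cached-executor timeout, and the computed
// expiry is clamped so that a very large timeout cannot overflow Long (as in onExecutorIdle).
def idleExpiryTime(
    now: Long,
    hasCachedBlocks: Boolean,
    executorIdleTimeoutS: Long,
    cachedExecutorIdleTimeoutS: Long): Long = {
  val timeoutS = if (hasCachedBlocks) cachedExecutorIdleTimeoutS else executorIdleTimeoutS
  val expiry = now + timeoutS * 1000
  if (expiry <= 0) Long.MaxValue else expiry // overflow guard
}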
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 91f9ef8ce7185..48792a958130c 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def isCompleted: Boolean = jobWaiter.jobFinished
-
+
override def isCancelled: Boolean = _cancelled
override def value: Option[Try[T]] = {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index f2b024ff6cb67..6909015ff66e6 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils}
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
- * components to convey liveness or execution information for in-progress tasks. It will also
+ * components to convey liveness or execution information for in-progress tasks. It will also
* expire the hosts that have not heartbeated for more than spark.network.timeout.
*/
private[spark] case class Heartbeat(
@@ -43,8 +43,8 @@ private[spark] case class Heartbeat(
*/
private[spark] case object TaskSchedulerIsSet
-private[spark] case object ExpireDeadHosts
-
+private[spark] case object ExpireDeadHosts
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
@@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
// "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses
// "milliseconds"
- private val slaveTimeoutMs =
+ private val slaveTimeoutMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s")
- private val executorTimeoutMs =
+ private val executorTimeoutMs =
sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000
-
+
// "spark.network.timeoutInterval" uses "seconds", while
// "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds"
- private val timeoutIntervalMs =
+ private val timeoutIntervalMs =
sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s")
- private val checkTimeoutIntervalMs =
+ private val checkTimeoutIntervalMs =
sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000
-
+
private var timeoutCheckingTask: ScheduledFuture[_] = null
// "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not
@@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}
}
-
+
override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
timeoutCheckingTask.cancel(true)
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 7e706bcc42f04..7cf7bc0dc6810 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -50,8 +50,8 @@ private[spark] class HttpFileServer(
def stop() {
httpServer.stop()
-
- // If we only stop sc, but the driver process still run as a services then we need to delete
+
+    // If we only stop sc, but the driver process still runs as a service, then we need to delete
// the tmp dir, if not, it will create too many tmp dirs
try {
Utils.deleteRecursively(baseDir)
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index b8d244408bc5b..82889bcd30988 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
*/
class RangePartitioner[K : Ordering : ClassTag, V](
@transient partitions: Int,
- @transient rdd: RDD[_ <: Product2[K,V]],
+ @transient rdd: RDD[_ <: Product2[K, V]],
private var ascending: Boolean = true)
extends Partitioner {
@@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
override def equals(other: Any): Boolean = other match {
- case r: RangePartitioner[_,_] =>
+ case r: RangePartitioner[_, _] =>
r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
case _ =>
false
@@ -249,7 +249,7 @@ private[spark] object RangePartitioner {
* @param sampleSizePerPartition max sample size per partition
* @return (total number of items, an array of (partitionId, number of items, sample))
*/
- def sketch[K:ClassTag](
+ def sketch[K : ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
val shift = rdd.id
@@ -272,7 +272,7 @@ private[spark] object RangePartitioner {
* @param partitions number of partitions
* @return selected bounds
*/
- def determineBounds[K:Ordering:ClassTag](
+ def determineBounds[K : Ordering : ClassTag](
candidates: ArrayBuffer[(K, Float)],
partitions: Int): Array[K] = {
val ordering = implicitly[Ordering[K]]
diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala
deleted file mode 100644
index 54fc3a856adfa..0000000000000
--- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.apache.spark.annotation.DeveloperApi
-
-/**
- * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
- * memory-aware caches.
- *
- * Based on the following JavaWorld article:
- * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
- */
-@DeveloperApi
-object SizeEstimator {
- /**
- * :: DeveloperApi ::
- * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate
- * includes space taken up by objects referenced by the given object, their references, and so on
- * and so forth.
- *
- * This is useful for determining the amount of heap space a broadcast variable will occupy on
- * each executor or the amount of space each object will take when caching objects in
- * deserialized form. This is not the same as the serialized size of the object, which will
- * typically be much smaller.
- */
- @DeveloperApi
- def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj)
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index b5e5d6f1465f3..46d72841dccce 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsBytes(key: String, defaultValue: String): Long = {
Utils.byteStringAsBytes(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Kibibytes are assumed.
@@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsKb(key: String, defaultValue: String): Long = {
Utils.byteStringAsKb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Mebibytes are assumed.
@@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsMb(key: String, defaultValue: String): Long = {
Utils.byteStringAsMb(get(key, defaultValue))
}
-
+
/**
* Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Gibibytes are assumed.
@@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getSizeAsGb(key: String, defaultValue: String): Long = {
Utils.byteStringAsGb(get(key, defaultValue))
}
-
+
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
Option(settings.get(key)).orElse(getDeprecatedConfig(key, this))
@@ -480,8 +480,8 @@ private[spark] object SparkConf extends Logging {
"spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " +
"are no longer accepted. To specify the equivalent now, one may use '64k'.")
)
-
- Map(configs.map { cfg => (cfg.key -> cfg) }:_*)
+
+ Map(configs.map { cfg => (cfg.key -> cfg) } : _*)
}
/**
@@ -508,7 +508,7 @@ private[spark] object SparkConf extends Logging {
"spark.reducer.maxSizeInFlight" -> Seq(
AlternateConfig("spark.reducer.maxMbInFlight", "1.4")),
"spark.kryoserializer.buffer" ->
- Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
+ Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4",
translation = s => s"${(s.toDouble * 1000).toInt}k")),
"spark.kryoserializer.buffer.max" -> Seq(
AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")),
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ea6c0dea08e47..a453c9bf4864a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -389,7 +389,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
- _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten
+ _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten
_files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0))
.toSeq.flatten
@@ -438,7 +438,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_ui =
if (conf.getBoolean("spark.ui.enabled", true)) {
Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener,
- _env.securityManager,appName, startTime = startTime))
+ _env.securityManager, appName, startTime = startTime))
} else {
// For tests, do not enable the UI
None
@@ -917,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
classOf[FixedLengthBinaryInputFormat],
classOf[LongWritable],
classOf[BytesWritable],
- conf=conf)
+ conf = conf)
val data = br.map { case (k, v) =>
val bytes = v.getBytes
assert(bytes.length == recordLength, "Byte array does not have correct length")
@@ -1267,7 +1267,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
- val param = new GrowableAccumulableParam[R,T]
+ val param = new GrowableAccumulableParam[R, T]
val acc = new Accumulable(initialValue, param)
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
@@ -1316,7 +1316,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val uri = new URI(path)
val schemeCorrectedPath = uri.getScheme match {
case null | "local" => new File(path).getCanonicalFile.toURI.toString
- case _ => path
+ case _ => path
}
val hadoopPath = new Path(schemeCorrectedPath)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 327114542880d..a185954089528 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -298,7 +298,7 @@ object SparkEnv extends Logging {
}
}
- val mapOutputTracker = if (isDriver) {
+ val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf)
} else {
new MapOutputTrackerWorker(conf)
@@ -348,7 +348,7 @@ object SparkEnv extends Logging {
val fileServerPort = conf.getInt("spark.fileserver.port", 0)
val server = new HttpFileServer(conf, securityManager, fileServerPort)
server.initialize()
- conf.set("spark.fileserver.uri", server.serverUri)
+ conf.set("spark.fileserver.uri", server.serverUri)
server
} else {
null
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2ec42d3aea169..59ac82ccec53b 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -50,8 +50,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
private var jID: SerializableWritable[JobID] = null
private var taID: SerializableWritable[TaskAttemptID] = null
- @transient private var writer: RecordWriter[AnyRef,AnyRef] = null
- @transient private var format: OutputFormat[AnyRef,AnyRef] = null
+ @transient private var writer: RecordWriter[AnyRef, AnyRef] = null
+ @transient private var format: OutputFormat[AnyRef, AnyRef] = null
@transient private var committer: OutputCommitter = null
@transient private var jobContext: JobContext = null
@transient private var taskContext: TaskAttemptContext = null
@@ -114,10 +114,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
// ********* Private Functions *********
- private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
+ private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = {
if (format == null) {
format = conf.value.getOutputFormat()
- .asInstanceOf[OutputFormat[AnyRef,AnyRef]]
+ .asInstanceOf[OutputFormat[AnyRef, AnyRef]]
}
format
}
@@ -138,7 +138,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
- taskContext = newTaskAttemptContext(conf.value, taID.value)
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
}
taskContext
}
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index fe6320b504e15..a1ebbecf93b7b 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -51,7 +51,7 @@ private[spark] object TestUtils {
classpathUrls: Seq[URL] = Seq()): URL = {
val tempDir = Utils.createTempDir()
val files1 = for (name <- classNames) yield {
- createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
+ createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
}
val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 61af867b11b9c..a650df605b92e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double])
*/
def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index db4e996feb31c..ed312770ee131 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
@@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
- *
+ *
* @param withReplacement can elements be sampled multiple times (replaced when sampled out)
* @param fraction expected size of the sample as a fraction of this RDD's size
* without replacement: probability that each element is chosen; fraction must be [0, 1]
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 74db7643224f5..c95615a5a9307 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@deprecated("Use partitions() instead.", "1.1.0")
def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
-
+
/** Set of partitions in this RDD. */
def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
+ /** The partitioner of this RDD. */
+ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner)
+
/** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
@@ -96,7 +99,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def mapPartitionsWithIndex[R](
f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
- new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
+ new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))),
preservesPartitioning)(fakeClassTag))(fakeClassTag)
/**
@@ -492,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
- def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+ def takeSample(withReplacement: Boolean, num: Int): JList[T] =
takeSample(withReplacement, num, Utils.random.nextLong)
-
+
def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2d92f6a42b308..55a37f8c944b2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -723,7 +723,7 @@ private[spark] object PythonRDD extends Logging {
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new JavaToWritableConverter)
val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]]
- converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec)
+ converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec)
}
/**
@@ -797,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
- /**
+ /**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
* by the DAGScheduler's single-threaded actor anyway.
- */
+ */
@transient var socket: Socket = _
def openSocket(): Socket = synchronized {
@@ -843,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
 * A wrapper for a Python Broadcast, which is written to disk by Python. It will also
 * write the data to disk after deserialization, so Python can read it from disk.
*/
+// scalastyle:off no.finalize
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
/**
@@ -884,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
}
}
}
+// scalastyle:on no.finalize
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 0a91977928cee..d24c650d37bb0 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -44,11 +44,11 @@ private[spark] class RBackend {
bossGroup = new NioEventLoopGroup(2)
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
-
+
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(classOf[NioServerSocketChannel])
-
+
bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
def initChannel(ch: SocketChannel): Unit = {
ch.pipeline()
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 0075d963711f1..2e86984c66b3a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend)
val reply = bos.toByteArray
ctx.write(reply)
}
-
+
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
ctx.flush()
}
@@ -124,7 +124,7 @@ private[r] class RBackendHandler(server: RBackend)
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
- val ret = methods.head.invoke(obj, args:_*)
+ val ret = methods.head.invoke(obj, args : _*)
// Write status bit
writeInt(dos, 0)
@@ -135,7 +135,7 @@ private[r] class RBackendHandler(server: RBackend)
matchMethod(numArgs, args, x.getParameterTypes)
}.head
- val obj = ctor.newInstance(args:_*)
+ val obj = ctor.newInstance(args : _*)
writeInt(dos, 0)
writeObject(dos, obj.asInstanceOf[AnyRef])
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 06247f7e8b78c..4dfa7325934ff 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -309,7 +309,7 @@ private class StringRRDD[T: ClassTag](
}
private object SpecialLengths {
- val TIMING_DATA = -1
+ val TIMING_DATA = -1
}
private[r] class BufferedStreamThread(
@@ -355,7 +355,6 @@ private[r] object RRDD {
val sparkConf = new SparkConf().setAppName(appName)
.setSparkHome(sparkHome)
- .setJars(jars)
// Override `master` if we have a user-specified value
if (master != "") {
@@ -373,7 +372,11 @@ private[r] object RRDD {
sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String])
}
- new JavaSparkContext(sparkConf)
+ val jsc = new JavaSparkContext(sparkConf)
+ jars.foreach { jar =>
+ jsc.addJar(jar)
+ }
+ jsc
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 371dfe454d1a2..f8e3f1a79082e 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -157,9 +157,11 @@ private[spark] object SerDe {
val keysLen = readInt(in)
val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
- val valuesType = readObjectType(in)
val valuesLen = readInt(in)
- val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType))
+ val values = (0 until valuesLen).map(_ => {
+ val valueType = readObjectType(in)
+ readTypedObject(in, valueType)
+ })
mapAsJavaMap(keys.zip(values).toMap)
} else {
new java.util.HashMap[Object, Object]()
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 4457c75e8b0fc..b69af639f7862 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging {
securityManager = securityMgr
if (isDriver) {
createServer(conf)
- conf.set("spark.httpBroadcast.uri", serverUri)
+ conf.set("spark.httpBroadcast.uri", serverUri)
}
serverUri = conf.get("spark.httpBroadcast.uri")
cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
@@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging {
}
private def read[T: ClassTag](id: Long): T = {
- logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
var uc: URLConnection = null
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index c048b78910f38..b4edb6109e839 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging {
private val workers = ListBuffer[TestWorkerInfo]()
private var sc: SparkContext = _
- private val zk = SparkCuratorUtil.newClient(conf)
+ private val zk = SparkCuratorUtil.newClient(conf)
private var numPassed = 0
private var numFailed = 0
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 198371b70f14f..a0eae774268ed 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -82,13 +82,13 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var exitFn: () => Unit = () => System.exit(1)
+ private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode)
private[spark] var printStream: PrintStream = System.err
private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String): Unit = {
printStream.println("Error: " + str)
printStream.println("Run with --help for usage help or --verbose for debug output")
- exitFn()
+ exitFn(1)
}
private[spark] def printVersionAndExit(): Unit = {
printStream.println("""Welcome to
@@ -99,7 +99,7 @@ object SparkSubmit {
/_/
""".format(SPARK_VERSION))
printStream.println("Type --help for more information.")
- exitFn()
+ exitFn(0)
}
def main(args: Array[String]): Unit = {
@@ -160,7 +160,7 @@ object SparkSubmit {
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
- exitFn()
+ exitFn(1)
} else {
throw e
}
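Because `exitFn` now takes the exit status and is exposed for testing, a test can stub it out instead of letting `SparkSubmit` terminate the JVM. A rough sketch, assuming test code inside the `org.apache.spark` namespace (the variable name is hypothetical):

// Capture the exit code rather than exiting the JVM (illustrative test-style snippet).
var capturedExitCode: Option[Int] = None
SparkSubmit.exitFn = (code: Int) => capturedExitCode = Some(code)
// ... invoke SparkSubmit.main(args) and assert on capturedExitCode ...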
@@ -361,7 +361,7 @@ object SparkSubmit {
pyArchives = pythonPath.mkString(",")
}
- pyArchives = pyArchives.split(",").map { localPath=>
+ pyArchives = pyArchives.split(",").map { localPath =>
val localURI = Utils.resolveURI(localPath)
if (localURI.getScheme != "local") {
args.files = mergeFileLists(args.files, localURI.toString)
@@ -425,9 +425,10 @@ object SparkSubmit {
// Yarn client only
OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"),
- OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"),
OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
+ OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"),
+ OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"),
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
@@ -440,13 +441,11 @@ object SparkSubmit {
OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
-
- // Yarn client or cluster
- OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"),
- OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"),
+ OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"),
+ OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),
// Other options
- OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES,
+ OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.cores"),
OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.memory"),
@@ -700,7 +699,7 @@ object SparkSubmit {
/**
* Return whether the given main class represents a sql shell.
*/
- private def isSqlShell(mainClass: String): Boolean = {
+ private[deploy] def isSqlShell(mainClass: String): Boolean = {
mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
}
@@ -869,7 +868,7 @@ private[spark] object SparkSubmitUtils {
md.addDependency(dd)
}
}
-
+
/** Add exclusion rules for dependencies already included in the spark-assembly */
def addExclusionRules(
ivySettings: IvySettings,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index c0e4c771908b3..b7429a901e162 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,12 +17,15 @@
package org.apache.spark.deploy
+import java.io.{ByteArrayOutputStream, PrintStream}
+import java.lang.reflect.InvocationTargetException
import java.net.URI
import java.util.{List => JList}
import java.util.jar.JarFile
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.io.Source
import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.launcher.SparkSubmitArgumentsParser
@@ -169,6 +172,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
numExecutors = Option(numExecutors)
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
+ keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
+ principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
@@ -410,6 +415,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
case VERSION =>
SparkSubmit.printVersionAndExit()
+ case USAGE_ERROR =>
+ printUsageAndExit(1)
+
case _ =>
throw new IllegalArgumentException(s"Unexpected argument '$opt'.")
}
@@ -447,11 +455,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
}
- outStream.println(
+ val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse(
"""Usage: spark-submit [options] [app arguments]
|Usage: spark-submit --kill [submission ID] --master [spark://...]
- |Usage: spark-submit --status [submission ID] --master [spark://...]
- |
+ |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin)
+ outStream.println(command)
+
+ outStream.println(
+ """
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -523,6 +534,65 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| delegation tokens periodically.
""".stripMargin
)
- SparkSubmit.exitFn()
+
+ if (SparkSubmit.isSqlShell(mainClass)) {
+ outStream.println("CLI options:")
+ outStream.println(getSqlShellOptions())
+ }
+
+ SparkSubmit.exitFn(exitCode)
}
+
+ /**
+ * Run the Spark SQL CLI main class with the "--help" option and catch its output. Then filter
+ * the results to remove unwanted lines.
+ *
+ * Since the CLI will call `System.exit()`, we install a security manager to prevent that call
+ * from working, and restore the original one afterwards.
+ */
+ private def getSqlShellOptions(): String = {
+ val currentOut = System.out
+ val currentErr = System.err
+ val currentSm = System.getSecurityManager()
+ try {
+ val out = new ByteArrayOutputStream()
+ val stream = new PrintStream(out)
+ System.setOut(stream)
+ System.setErr(stream)
+
+ val sm = new SecurityManager() {
+ override def checkExit(status: Int): Unit = {
+ throw new SecurityException()
+ }
+
+ override def checkPermission(perm: java.security.Permission): Unit = {}
+ }
+ System.setSecurityManager(sm)
+
+ try {
+ Class.forName(mainClass).getMethod("main", classOf[Array[String]])
+ .invoke(null, Array(HELP))
+ } catch {
+ case e: InvocationTargetException =>
+ // Ignore SecurityException, since we throw it above.
+ if (!e.getCause().isInstanceOf[SecurityException]) {
+ throw e
+ }
+ }
+
+ stream.flush()
+
+ // Get the output and discard any unnecessary lines from it.
+ Source.fromString(new String(out.toByteArray())).getLines
+ .filter { line =>
+ !line.startsWith("log4j") && !line.startsWith("usage")
+ }
+ .mkString("\n")
+ } finally {
+ System.setSecurityManager(currentSm)
+ System.setOut(currentOut)
+ System.setErr(currentErr)
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index 298a8201960d1..5f5e0fe1c34d7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -17,6 +17,9 @@
package org.apache.spark.deploy.history
+import java.util.zip.ZipOutputStream
+
+import org.apache.spark.SparkException
import org.apache.spark.ui.SparkUI
private[spark] case class ApplicationAttemptInfo(
@@ -62,4 +65,12 @@ private[history] abstract class ApplicationHistoryProvider {
*/
def getConfig(): Map[String, String] = Map()
+ /**
+ * Writes out the event logs to the output stream provided. The logs will be compressed into a
+ * single zip file and written out.
+ * @throws SparkException if the logs for the app id cannot be found.
+ */
+ @throws(classOf[SparkException])
+ def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit
+
}
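A rough usage sketch for the new `writeEventLogs` hook, assuming some `ApplicationHistoryProvider` instance named `provider` and a hypothetical application id; passing `None` for the attempt id zips every attempt, per the contract above:

import java.io.FileOutputStream
import java.util.zip.ZipOutputStream

val zipStream = new ZipOutputStream(new FileOutputStream("/tmp/eventLogs.zip"))
try {
  provider.writeEventLogs("app-0001", None, zipStream) // the app id is illustrative
} finally {
  zipStream.close()
}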
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 45c2be34c8680..5427a88f32ffd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -17,16 +17,18 @@
package org.apache.spark.deploy.history
-import java.io.{BufferedInputStream, FileNotFoundException, IOException, InputStream}
+import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream}
import java.util.concurrent.{ExecutorService, Executors, TimeUnit}
+import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.mutable
+import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.fs.permission.AccessControlException
-import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
import org.apache.spark.scheduler._
@@ -59,7 +61,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
.map { d => Utils.resolveURI(d).toString }
.getOrElse(DEFAULT_LOG_DIR)
- private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
+ private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf)
// Used by check event thread and clean log thread.
// Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs
@@ -219,6 +222,58 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
}
}
+ override def writeEventLogs(
+ appId: String,
+ attemptId: Option[String],
+ zipStream: ZipOutputStream): Unit = {
+
+ /**
+ * This method compresses the files passed in, and writes the compressed data out into the
+ * [[OutputStream]] passed in. Each file is written as a new [[ZipEntry]] with its name being
+ * the name of the file being compressed.
+ */
+ def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = {
+ val fs = FileSystem.get(hadoopConf)
+ val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer
+ try {
+ outputStream.putNextEntry(new ZipEntry(entryName))
+ ByteStreams.copy(inputStream, outputStream)
+ outputStream.closeEntry()
+ } finally {
+ inputStream.close()
+ }
+ }
+
+ applications.get(appId) match {
+ case Some(appInfo) =>
+ try {
+ // If no attempt is specified, or there is no attemptId for attempts, return all attempts
+ appInfo.attempts.filter { attempt =>
+ attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get
+ }.foreach { attempt =>
+ val logPath = new Path(logDir, attempt.logPath)
+ // If this is a legacy directory, then add the directory to the zipStream and add
+ // each file to that directory.
+ if (isLegacyLogDirectory(fs.getFileStatus(logPath))) {
+ val files = fs.listStatus(logPath)
+ zipStream.putNextEntry(new ZipEntry(attempt.logPath + "/"))
+ zipStream.closeEntry()
+ files.foreach { file =>
+ val path = file.getPath
+ zipFileToStream(path, attempt.logPath + Path.SEPARATOR + path.getName, zipStream)
+ }
+ } else {
+ zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream)
+ }
+ }
+ } finally {
+ zipStream.close()
+ }
+ case None => throw new SparkException(s"Logs for $appId not found.")
+ }
+ }
+
+
/**
* Replay the log files in the list and merge the list of old applications with new ones
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 5a0eb585a9049..10638afb74900 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.history
import java.util.NoSuchElementException
+import java.util.zip.ZipOutputStream
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import com.google.common.cache._
@@ -173,6 +174,13 @@ class HistoryServer(
getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo)
}
+ override def writeEventLogs(
+ appId: String,
+ attemptId: Option[String],
+ zipStream: ZipOutputStream): Unit = {
+ provider.writeEventLogs(appId, attemptId, zipStream)
+ }
+
/**
* Returns the provider configuration to show in the listing page.
*
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index a2a97a7877ce7..4692d22651c93 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -23,7 +23,7 @@ import org.apache.spark.util.Utils
/**
 * Command-line parser for the history server.
*/
-private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
+private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String])
extends Logging {
private var propertiesFile: String = null
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 80db6d474b5c1..328d95a7a0c68 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil
private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
extends PersistenceEngine
with Logging {
-
+
private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 756927682cd24..6a7c74020bace 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -75,6 +75,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
+ val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
@@ -108,12 +109,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}.getOrElse { Seq.empty }
}
-
Workers: {state.workers.size}
-
Cores: {state.workers.map(_.cores).sum} Total,
- {state.workers.map(_.coresUsed).sum} Used
-
Memory:
- {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
- {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
+
Alive Workers: {aliveWorkers.size}
+
Cores in use: {aliveWorkers.map(_.cores).sum} Total,
+ {aliveWorkers.map(_.coresUsed).sum} Used
+
Memory in use:
+ {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total,
+ {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 4dbccb9e6e46c..2128fdffecc05 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -7,11 +7,7 @@ redirect_from: "building-with-maven.html"
* This will become a table of contents (this text will be scraped).
{:toc}
-Building Spark using Maven requires Maven 3.0.4 or newer and Java 6+.
-
-**Note:** Building Spark with Java 7 or later can create JAR files that may not be
-readable with early versions of Java 6, due to the large number of files in the JAR
-archive. Build with Java 6 if this is an issue for your deployment.
+Building Spark using Maven requires Maven 3.0.4 or newer and Java 7+.
# Building with `build/mvn`
@@ -80,6 +76,7 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro
2.2.x
hadoop-2.2
2.3.x
hadoop-2.3
2.4.x
hadoop-2.4
+
2.6.x and later 2.x
hadoop-2.6
@@ -118,14 +115,10 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options.
-By default Spark will build with Hive 0.13.1 bindings. You can also build for
-Hive 0.12.0 using the `-Phive-0.12.0` profile.
+By default Spark will build with Hive 0.13.1 bindings.
{% highlight bash %}
# Apache Hadoop 2.4.X with Hive 13 support
mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package
-
-# Apache Hadoop 2.4.X with Hive 12 support
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package
{% endhighlight %}
# Building for Scala 2.11
@@ -134,9 +127,7 @@ To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` prop
dev/change-version-to-2.11.sh
mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
-Scala 2.11 support in Spark does not support a few features due to dependencies
-which are themselves not Scala 2.11 ready. Specifically, Spark's external
-Kafka library and JDBC component are not yet supported in Scala 2.11 builds.
+Spark does not yet support its JDBC component for Scala 2.11.
# Spark Tests in Maven
@@ -180,7 +171,7 @@ Thus, the full flow for running continuous-compilation of the `core` submodule m
# Building Spark with IntelliJ IDEA or Eclipse
For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the
-[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-IDESetup).
+[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup).
# Running Java 8 Test Suites
diff --git a/docs/configuration.md b/docs/configuration.md
index 30508a617fdd8..3960e7e78bde1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1,4 +1,4 @@
---
+---
layout: global
displayTitle: Spark Configuration
title: Configuration
@@ -618,7 +618,7 @@ Apart from these, the following properties are also available, and may be useful
spark.kryo.referenceTracking
-
true
+
true (false when using Spark SQL Thrift Server)
Whether to track references to the same object when serializing data with Kryo, which is
necessary if your object graphs have loops and useful for efficiency if they contain multiple
@@ -679,7 +679,10 @@ Apart from these, the following properties are also available, and may be useful
spark.serializer
-
org.apache.spark.serializer. JavaSerializer
+
+ org.apache.spark.serializer. JavaSerializer (org.apache.spark.serializer.
+ KryoSerializer when using Spark SQL Thrift Server)
+
Class to use for serializing objects that will be sent over the network or need to be cached
in serialized form. The default of Java serialization works with any Serializable Java object
@@ -1201,6 +1204,15 @@ Apart from these, the following properties are also available, and may be useful
description.
+
+
spark.dynamicAllocation.cachedExecutorIdleTimeout
+
2 * executorIdleTimeout
+
+ If dynamic allocation is enabled and an executor which has cached data blocks has been idle for more than this duration,
+ the executor will be removed. For more details, see this
+ description.
+
+
spark.dynamicAllocation.initialExecutors
spark.dynamicAllocation.minExecutors
diff --git a/docs/hadoop-provided.md b/docs/hadoop-provided.md
new file mode 100644
index 0000000000000..0ba5a58051abc
--- /dev/null
+++ b/docs/hadoop-provided.md
@@ -0,0 +1,26 @@
+---
+layout: global
+displayTitle: Using Spark's "Hadoop Free" Build
+title: Using Spark's "Hadoop Free" Build
+---
+
+Spark uses Hadoop client libraries for HDFS and YARN. Starting in Spark 1.4, the project packages "Hadoop free" builds that let you more easily connect a single Spark binary to any Hadoop version. To use these builds, you need to modify `SPARK_DIST_CLASSPATH` to include Hadoop's package jars. The most convenient place to do this is by adding an entry in `conf/spark-env.sh`.
+
+This page describes how to connect Spark to Hadoop for different types of distributions.
+
+# Apache Hadoop
+For Apache distributions, you can use Hadoop's 'classpath' command. For instance:
+
+{% highlight bash %}
+### in conf/spark-env.sh ###
+
+# If 'hadoop' binary is on your PATH
+export SPARK_DIST_CLASSPATH=$(hadoop classpath)
+
+# With explicit path to 'hadoop' binary
+export SPARK_DIST_CLASSPATH=$(/path/to/hadoop/bin/hadoop classpath)
+
+# Passing a Hadoop configuration directory
+export SPARK_DIST_CLASSPATH=$(hadoop classpath --config /path/to/configs)
+
+{% endhighlight %}
diff --git a/docs/index.md b/docs/index.md
index 5ef6d983c45a5..d85cf12defefd 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -12,15 +12,19 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog
# Downloading
-Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page
-contains Spark packages for many popular HDFS versions. If you'd like to build Spark from
-scratch, visit [Building Spark](building-spark.html).
+Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark uses Hadoop's client libraries for HDFS and YARN. Downloads are pre-packaged for a handful of popular Hadoop versions.
+Users can also download a "Hadoop free" binary and run Spark with any Hadoop version
+[by augmenting Spark's classpath](hadoop-provided.html).
+
+If you'd like to build Spark from
+source, visit [Building Spark](building-spark.html).
+
Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run
locally on one machine --- all you need is to have `java` installed on your system `PATH`,
or the `JAVA_HOME` environment variable pointing to a Java installation.
-Spark runs on Java 6+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses
+Spark runs on Java 7+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses
Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version
({{site.SCALA_BINARY_VERSION}}.x).
@@ -54,7 +58,7 @@ Example applications are also provided in Python. For example,
./bin/spark-submit examples/src/main/python/pi.py 10
-Spark also provides an experimental R API since 1.4 (only DataFrames APIs included).
+Since 1.4, Spark also provides an experimental [R API](sparkr.html) (only the DataFrames API is included).
To run Spark interactively in an R interpreter, use `bin/sparkR`:
./bin/sparkR --master local[2]
diff --git a/docs/ml-features.md b/docs/ml-features.md
index efe9b3b8edb6e..f88c0248c1a8a 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
+## StringIndexer
+
+`StringIndexer` encodes a string column of labels to a column of label indices.
+The indices are in `[0, numLabels)`, ordered by label frequencies.
+So the most frequent label gets index `0`.
+If the input column is numeric, we cast it to string and index the string values.
+
+**Examples**
+
+Assume that we have the following DataFrame with columns `id` and `category`:
+
+~~~~
+ id | category
+----|----------
+ 0 | a
+ 1 | b
+ 2 | c
+ 3 | a
+ 4 | a
+ 5 | c
+~~~~
+
+`category` is a string column with three labels: "a", "b", and "c".
+Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
+column, we should get the following:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0 | a | 0.0
+ 1 | b | 2.0
+ 2 | c | 1.0
+ 3 | a | 0.0
+ 4 | a | 0.0
+ 5 | c | 1.0
+~~~~
+
+"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
+index `2`.
+
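+A minimal Scala sketch of this transformation (assuming the `spark.ml` `StringIndexer` estimator and an existing `sqlContext`):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.StringIndexer
+
+val df = sqlContext.createDataFrame(
+  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
+).toDF("id", "category")
+
+val indexer = new StringIndexer()
+  .setInputCol("category")
+  .setOutputCol("categoryIndex")
+
+// Fit on the labels, then append the categoryIndex column.
+val indexed = indexer.fit(df).transform(df)
+{% endhighlight %}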
+
+
## OneHotEncoder
[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
@@ -789,6 +905,294 @@ scaledData = scalerModel.transform(dataFrame)
+## Bucketizer
+
+`Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter:
+
+* `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x, y holds values in the range [x, y), except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, values outside the specified splits will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`.
+
+Note that if you do not know the upper and lower bounds of the targeted column, you should add `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potential out-of-bounds exception.
+
+Note also that the splits you provide must be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`.
+
+More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer).
+
+The following example demonstrates how to bucketize a column of `Double`s into another column of bucket indices.
+
+
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.Bucketizer
+import org.apache.spark.sql.DataFrame
+
+val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
+
+val data = Array(-0.5, -0.3, 0.0, 0.2)
+val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
+
+val bucketizer = new Bucketizer()
+ .setInputCol("features")
+ .setOutputCol("bucketedFeatures")
+ .setSplits(splits)
+
+// Transform original data into its bucket index.
+val bucketedData = bucketizer.transform(dataFrame)
+{% endhighlight %}
+
+
+## ElementwiseProduct
+
+ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector `v` and the transforming vector `w`, to yield a result vector.
+
+`\[ \begin{pmatrix}
+v_1 \\
+\vdots \\
+v_N
+\end{pmatrix} \circ \begin{pmatrix}
+ w_1 \\
+ \vdots \\
+ w_N
+ \end{pmatrix}
+= \begin{pmatrix}
+ v_1 w_1 \\
+ \vdots \\
+ v_N w_N
+ \end{pmatrix}
+\]`
+
+[`ElementwiseProduct`](api/scala/index.html#org.apache.spark.ml.feature.ElementwiseProduct) takes the following parameter:
+
+* `scalingVec`: the transforming vector.
+
+The example below demonstrates how to transform vectors using a transforming vector value.
+
+
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.ElementwiseProduct
+import org.apache.spark.mllib.linalg.Vectors
+
+// Create some vector data; also works for sparse vectors
+val dataFrame = sqlContext.createDataFrame(Seq(
+ ("a", Vectors.dense(1.0, 2.0, 3.0)),
+ ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector")
+
+val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
+val transformer = new ElementwiseProduct()
+ .setScalingVec(transformingVector)
+ .setInputCol("vector")
+ .setOutputCol("transformedVector")
+
+// Batch transform the vectors to create new column:
+val transformedData = transformer.transform(dataFrame)
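+// Each row's vector is scaled element-wise by (0.0, 1.0, 2.0); e.g. (1.0, 2.0, 3.0) becomes (0.0, 2.0, 6.0).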
+
+{% endhighlight %}
+
+
+## VectorAssembler
+
+`VectorAssembler` is a transformer that combines a given list of columns into a single vector
+column.
+It is useful for combining raw features and features generated by different feature transformers
+into a single feature vector, in order to train ML models like logistic regression and decision
+trees.
+`VectorAssembler` accepts the following input column types: all numeric types, boolean type,
+and vector type.
+In each row, the values of the input columns will be concatenated into a vector in the specified
+order.
+
+**Examples**
+
+Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`,
+and `clicked`:
+
+~~~
+ id | hour | mobile | userFeatures | clicked
+----|------|--------|------------------|---------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0
+~~~
+
+`userFeatures` is a vector column that contains three user features.
+We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector
+called `features` and use it to predict `clicked` or not.
+If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and
+output column to `features`, after transformation we should get the following DataFrame:
+
+~~~
+ id | hour | mobile | userFeatures | clicked | features
+----|------|--------|------------------|---------|-----------------------------
+ 0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5]
+~~~
+
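+A minimal Scala sketch of the transformation described above (assuming an existing `sqlContext`)
+could look like this:
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.mllib.linalg.Vectors
+
+val dataset = sqlContext.createDataFrame(
+  Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0))
+).toDF("id", "hour", "mobile", "userFeatures", "clicked")
+
+val assembler = new VectorAssembler()
+  .setInputCols(Array("hour", "mobile", "userFeatures"))
+  .setOutputCol("features")
+
+// "features" now contains [18.0, 1.0, 0.0, 10.0, 0.5] for the row above.
+val output = assembler.transform(dataset)
+{% endhighlight %}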
+
# Feature Selectors
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index c5f50ed7990f1..4eb622d4b95e8 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -207,7 +207,7 @@ val model1 = lr.fit(training.toDF)
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+println("Model 1 was fit using parameters: " + model1.parent.extractParamMap)
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -222,7 +222,7 @@ val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF, paramMapCombined)
-println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+println("Model 2 was fit using parameters: " + model2.parent.extractParamMap)
// Prepare test data.
val test = sc.parallelize(Seq(
@@ -289,7 +289,7 @@ LogisticRegressionModel model1 = lr.fit(training);
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
-System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -305,7 +305,7 @@ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
-System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List localTest = Lists.newArrayList(
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index f41ca70952eb7..1b088969ddc25 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -47,7 +47,7 @@ Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasin
optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight scala %}
-import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -62,6 +62,10 @@ val clusters = KMeans.train(parsedData, numClusters, numIterations)
// Evaluate clustering by computing Within Set Sum of Squared Errors
val WSSSE = clusters.computeCost(parsedData)
println("Within Set Sum of Squared Errors = " + WSSSE)
+
+// Save and load model
+clusters.save(sc, "myModelPath")
+val sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -110,6 +114,10 @@ public class KMeansExample {
// Evaluate clustering by computing Within Set Sum of Squared Errors
double WSSSE = clusters.computeCost(parsedData.rdd());
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+ // Save and load model
+ clusters.save(sc.sc(), "myModelPath");
+ KMeansModel sameModel = KMeansModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -124,7 +132,7 @@ Within Set Sum of Squared Error (WSSSE). You can reduce this error measure by in
fact the optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
{% highlight python %}
-from pyspark.mllib.clustering import KMeans
+from pyspark.mllib.clustering import KMeans, KMeansModel
from numpy import array
from math import sqrt
@@ -143,6 +151,10 @@ def error(point):
WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y)
print("Within Set Sum of Squared Error = " + str(WSSSE))
+
+# Save and load model
+clusters.save(sc, "myModelPath")
+sameModel = KMeansModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -237,11 +249,11 @@ public class GaussianMixtureExample {
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
// Save and load GaussianMixtureModel
- gmm.save(sc, "myGMMModel")
- GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
+ gmm.save(sc.sc(), "myGMMModel");
+ GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel");
// Output the parameters of the mixture model
for(int j=0; j
println(s"${a.id} -> ${a.cluster}")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath")
{% endhighlight %}
A full example that produces the experiment described in the PIC paper can be found under
@@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities);
for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
System.out.println(a.id() + " -> " + a.cluster());
}
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 7b397e30b2d90..dfdf6216b270c 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results.
{% highlight scala %}
val alpha = 0.01
-val model = ALS.trainImplicit(ratings, rank, numIterations, alpha)
+val lambda = 0.01
+val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha)
{% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index f723cd6b9dfab..4fe470a8de810 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th
import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
val input = sc.textFile("text8").map(line => line.split(" ").toSeq)
@@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40)
for((synonym, cosineSimilarity) <- synonyms) {
println(s"$synonym $cosineSimilarity")
}
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = Word2VecModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.feature.ChiSqSelector
// Load some data in libsvm format
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
@@ -505,7 +510,7 @@ v_N
### Example
-This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value.
+The example below demonstrates how to transform vectors using a transforming vector value.
@@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors
-// Load and parse the data:
-val data = sc.textFile("data/mllib/kmeans_data.txt")
-val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
+// Create some vector data; also works for sparse vectors
+val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)))
val transformingVector = Vectors.dense(0.0, 1.0, 2.0)
val transformer = new ElementwiseProduct(transformingVector)
// Batch transform and per-row transform give the same results:
-val transformedData = transformer.transform(parsedData)
-val transformedData2 = parsedData.map(x => transformer.transform(x))
+val transformedData = transformer.transform(data)
+val transformedData2 = data.map(x => transformer.transform(x))
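+// Both yield vectors scaled element-wise by (0.0, 1.0, 2.0); e.g. (1.0, 2.0, 3.0) becomes (0.0, 2.0, 6.0).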
+
+{% endhighlight %}
+
+
+
+{% highlight java %}
+import java.util.Arrays;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.feature.ElementwiseProduct;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+// Create some vector data; also works for sparse vectors
+JavaRDD<Vector> data = sc.parallelize(Arrays.asList(
+  Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0)));
+Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
+final ElementwiseProduct transformer = new ElementwiseProduct(transformingVector);
+
+// Batch transform and per-row transform give the same results:
+JavaRDD<Vector> transformedData = transformer.transform(data);
+JavaRDD<Vector> transformedData2 = data.map(
+  new Function<Vector, Vector>() {
+ @Override
+ public Vector call(Vector v) {
+ return transformer.transform(v);
+ }
+ }
+);
{% endhighlight %}
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md
index 9fd9be0dd01b1..bcc066a185526 100644
--- a/docs/mllib-frequent-pattern-mining.md
+++ b/docs/mllib-frequent-pattern-mining.md
@@ -39,11 +39,11 @@ MLlib's FP-growth implementation takes the following (hyper-)parameters:
-[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the
+[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the
FP-growth algorithm.
It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type.
Calling `FPGrowth.run` with transactions returns an
-[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html)
+[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel)
that stores the frequent itemsets with their frequencies.
{% highlight scala %}
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index b521c2f27cd6e..5732bc4c7e79e 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f
labels and real labels in the test set.
{% highlight scala %}
-import org.apache.spark.mllib.regression.IsotonicRegression
+import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel}
val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")
@@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point =>
// Calculate mean squared error between predicted and real labels.
val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean()
println("Mean Squared Error = " + meanSquaredError)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = IsotonicRegressionModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map(
).rdd()).mean();
System.out.println("Mean Squared Error = " + meanSquaredError);
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 8029edca16002..3dc8cc902fa72 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training
error.
{% highlight scala %}
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
// Load training data in LIBSVM format.
@@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
-import java.util.Random;
-
import scala.Tuple2;
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.*;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
-import org.apache.spark.mllib.linalg.Vector;
+
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.SparkConf;
@@ -282,8 +277,8 @@ public class SVMClassifier {
System.out.println("Area under ROC = " + auROC);
// Save and load model
- model.save(sc.sc(), "myModelPath");
- SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath");
+ model.save(sc, "myModelPath");
+ SVMModel sameModel = SVMModel.load(sc, "myModelPath");
}
}
{% endhighlight %}
@@ -315,15 +310,12 @@ a dependency.
-The following example shows how to load a sample dataset, build Logistic Regression model,
+The following example shows how to load a sample dataset, build an SVM model,
and make predictions with the resulting model to compute the training error.
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.classification import LogisticRegressionWithSGD
+from pyspark.mllib.classification import SVMWithSGD, SVMModel
from pyspark.mllib.regression import LabeledPoint
-from numpy import array
# Load and parse the data
def parsePoint(line):
@@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt")
parsedData = data.map(parsePoint)
# Build the model
-model = LogisticRegressionWithSGD.train(parsedData)
+model = SVMWithSGD.train(parsedData, iterations=100)
# Evaluating the model on training data
labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count())
print("Training Error = " + str(trainErr))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = SVMModel.load(sc, "myModelPath")
{% endhighlight %}
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 56a2e9ca86bb1..bf6d124fd5d8d 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -14,9 +14,8 @@ and use it for prediction.
MLlib supports [multinomial naive
Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
-and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
-These models are typically used for [document classification]
-(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
Within that context, each observation is a document and each
feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
@@ -54,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)
-val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
+val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial")
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
diff --git a/docs/monitoring.md b/docs/monitoring.md
index e75018499003a..bcf885fe4e681 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -228,6 +228,14 @@ for a running application, at `http://localhost:4040/api/v1`.
/applications/[app-id]/storage/rdd/[rdd-id]
Details for the storage status of a given RDD
+  <tr>
+    <td><code>/applications/[app-id]/logs</code></td>
+    <td>Download the event logs for all attempts of the given application as a zip file</td>
+  </tr>
+  <tr>
+    <td><code>/applications/[app-id]/[attempt-id]/logs</code></td>
+    <td>Download the event logs for the specified attempt of the given application as a zip file</td>
+  </tr>
When running on Yarn, each application has multiple attempts, so `[app-id]` is actually
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 10f474f237bfa..d5ff416fe89a4 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -54,7 +54,7 @@ import org.apache.spark.SparkConf
-Spark {{site.SPARK_VERSION}} works with Java 6 and higher. If you are using Java 8, Spark supports
+Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java 8, Spark supports
[lambda expressions](http://docs.oracle.com/javase/tutorial/java/javaOO/lambdaexpressions.html)
for concisely writing functions, otherwise you can use the classes in the
[org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 9d55f435e80ad..96cf612c54fdd 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -242,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
running against earlier versions, this property will be ignored.
+<tr>
+  <td><code>spark.yarn.keytab</code></td>
+  <td>(none)</td>
+  <td>
+    The full path to the file that contains the keytab for the principal specified above.
+    This keytab will be copied to the node running the Application Master via the Secure Distributed Cache,
+    for renewing the login tickets and the delegation tokens periodically.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.principal</code></td>
+  <td>(none)</td>
+  <td>
+    Principal to be used to login to KDC, while running on secure HDFS.
+  </td>
+</tr>
# Launching Spark on YARN
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 0eed9adacf123..12d7d6e159bea 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -77,7 +77,7 @@ Note, the master machine accesses each of the worker machines via ssh. By defaul
If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker.
-Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
+Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/sbin`:
- `sbin/start-master.sh` - Starts a master instance on the machine the script is executed on.
- `sbin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
diff --git a/docs/sparkr.md b/docs/sparkr.md
new file mode 100644
index 0000000000000..4d82129921a37
--- /dev/null
+++ b/docs/sparkr.md
@@ -0,0 +1,223 @@
+---
+layout: global
+displayTitle: SparkR (R on Spark)
+title: SparkR (R on Spark)
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+SparkR is an R package that provides a light-weight frontend to use Apache Spark from R.
+In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that
+supports operations like selection, filtering, aggregation etc. (similar to R data frames,
+[dplyr](https://github.com/hadley/dplyr)) but on large datasets.
+
+# SparkR DataFrames
+
+A DataFrame is a distributed collection of data organized into named columns. It is conceptually
+equivalent to a table in a relational database or a data frame in R, but with richer
+optimizations under the hood. DataFrames can be constructed from a wide array of sources such as:
+structured data files, tables in Hive, external databases, or existing local R data frames.
+
+All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell.
+
+## Starting Up: SparkContext, SQLContext
+
+
+The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster.
+You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name
+etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the
+SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should
+already be created for you.
+
+{% highlight r %}
+sc <- sparkR.init()
+sqlContext <- sparkRSQL.init(sc)
+{% endhighlight %}
+
+
+
+## Creating DataFrames
+With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
+
+### From local data frames
+The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically, we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based on the `faithful` dataset from R.
+
+
+{% highlight r %}
+df <- createDataFrame(sqlContext, faithful)
+
+# Displays the content of the DataFrame to stdout
+head(df)
+## eruptions waiting
+##1 3.600 79
+##2 1.800 54
+##3 3.333 74
+
+{% endhighlight %}
+
+
+### From Data Sources
+
+SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources.
+
+The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro).
+
+We can see how to use data sources with an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
+
+
+
+{% highlight r %}
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+head(people)
+## age name
+##1 NA Michael
+##2 30 Andy
+##3 19 Justin
+
+# SparkR automatically infers the schema from the JSON file
+printSchema(people)
+# root
+# |-- age: integer (nullable = true)
+# |-- name: string (nullable = true)
+
+{% endhighlight %}
+
+
+The data sources API can also be used to save out DataFrames into multiple file formats. For
+example, we can save the DataFrame from the previous example to a Parquet file using `write.df`.
+
+
+
+### From Hive tables
+
+You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext).
+
+
+{% highlight r %}
+# sc is an existing SparkContext.
+hiveContext <- sparkRHive.init(sc)
+
+sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+
+# Queries can be expressed in HiveQL.
+results <- sql(hiveContext, "FROM src SELECT key, value")
+
+# results is now a DataFrame
+head(results)
+## key value
+## 1 238 val_238
+## 2 86 val_86
+## 3 311 val_311
+
+{% endhighlight %}
+
+
+## DataFrame Operations
+
+SparkR DataFrames support a number of functions to do structured data processing.
+Here we include some basic examples; a complete list can be found in the [API](api/R/index.html) docs:
+
+### Selecting rows, columns
+
+
+{% highlight r %}
+# Create the DataFrame
+df <- createDataFrame(sqlContext, faithful)
+
+# Get basic information about the DataFrame
+df
+## DataFrame[eruptions:double, waiting:double]
+
+# Select only the "eruptions" column
+head(select(df, df$eruptions))
+## eruptions
+##1 3.600
+##2 1.800
+##3 3.333
+
+# You can also pass in column name as strings
+head(select(df, "eruptions"))
+
+# Filter the DataFrame to only retain rows with wait times shorter than 50 mins
+head(filter(df, df$waiting < 50))
+## eruptions waiting
+##1 1.750 47
+##2 1.750 47
+##3 1.867 48
+
+{% endhighlight %}
+
+
+
+### Grouping, Aggregation
+
+SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below.
+
+
+{% highlight r %}
+
+# We use the `n` operator to count the number of times each waiting time appears
+head(summarize(groupBy(df, df$waiting), count = n(df$waiting)))
+## waiting count
+##1 81 13
+##2 60 6
+##3 68 1
+
+# We can also sort the output from the aggregation to get the most common waiting times
+waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting))
+head(arrange(waiting_counts, desc(waiting_counts$count)))
+
+## waiting count
+##1 78 15
+##2 83 14
+##3 81 13
+
+{% endhighlight %}
+
+
+### Operating on Columns
+
+SparkR also provides a number of functions that can be directly applied to columns for data processing and aggregation. The example below shows the use of basic arithmetic functions.
+
+
+{% highlight r %}
+
+# Convert waiting time from minutes to seconds.
+# Note that we can assign this to a new column in the same DataFrame
+df$waiting_secs <- df$waiting * 60
+head(df)
+## eruptions waiting waiting_secs
+##1 3.600 79 4740
+##2 1.800 54 3240
+##3 3.333 74 4440
+
+{% endhighlight %}
+
+
+## Running SQL Queries from SparkR
+A SparkR DataFrame can also be registered as a temporary table in Spark SQL; registering a DataFrame as a table allows you to run SQL queries over its data.
+The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`.
+
+
+{% highlight r %}
+# Load a JSON file
+people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json")
+
+# Register this DataFrame as a table.
+registerTempTable(people, "people")
+
+# SQL statements can be run by using the sql method
+teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19")
+head(teenagers)
+## name
+##1 Justin
+
+{% endhighlight %}
+
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 5b41c0ee6e430..40e33f757d693 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -11,6 +11,7 @@ title: Spark SQL and DataFrames
Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine.
+For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section.
# DataFrames
@@ -108,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO
val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Displays the content of the DataFrame to stdout
df.show()
@@ -121,7 +122,7 @@ df.show()
JavaSparkContext sc = ...; // An existing JavaSparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Displays the content of the DataFrame to stdout
df.show();
@@ -134,7 +135,7 @@ df.show();
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Displays the content of the DataFrame to stdout
df.show()
@@ -170,7 +171,7 @@ val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+val df = sqlContext.read.json("examples/src/main/resources/people.json")
// Show the content of the DataFrame
df.show()
@@ -220,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext.
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc)
// Create the DataFrame
-DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
+DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json");
// Show the content of the DataFrame
df.show();
@@ -276,7 +277,7 @@ from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
# Create the DataFrame
-df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df = sqlContext.read.json("examples/src/main/resources/people.json")
# Show the content of the DataFrame
df.show()
@@ -776,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
Ignore mode means that when saving a DataFrame to a data source, if data already exists,
the save operation is expected to not save the contents of the DataFrame and to not
- change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL.
+ change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL.
@@ -946,11 +947,11 @@ import sqlContext.implicits._
val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet.
-people.saveAsParquetFile("people.parquet")
+people.write.parquet("people.parquet")
// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a Parquet file is also a DataFrame.
-val parquetFile = sqlContext.parquetFile("people.parquet")
+val parquetFile = sqlContext.read.parquet("people.parquet")
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile")
@@ -968,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
DataFrame schemaPeople = ... // The DataFrame from the previous example.
// DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet");
+schemaPeople.write().parquet("people.parquet");
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
// The result of loading a parquet file is also a DataFrame.
-DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
+DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -994,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function()
schemaPeople # The DataFrame from the previous example.
# DataFrames can be saved as Parquet files, maintaining the schema information.
-schemaPeople.saveAsParquetFile("people.parquet")
+schemaPeople.write.parquet("people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
# The result of loading a parquet file is also a DataFrame.
-parquetFile = sqlContext.parquetFile("people.parquet")
+parquetFile = sqlContext.read.parquet("people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
@@ -1030,7 +1031,7 @@ teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND a
teenNames <- map(teenagers, function(p) { paste("Name:", p$name)})
for (teenName in collect(teenNames)) {
cat(teenName, "\n")
-}
+}
{% endhighlight %}
@@ -1086,9 +1087,9 @@ path
{% endhighlight %}
-By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will
-automatically extract the partitioning information from the paths. Now the schema of the returned
-DataFrame becomes:
+By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL
+will automatically extract the partitioning information from the paths.
+Now the schema of the returned DataFrame becomes:
{% highlight text %}
@@ -1101,7 +1102,11 @@ root
{% endhighlight %}
Notice that the data types of the partitioning columns are automatically inferred. Currently,
-numeric data types and string type are supported.
+numeric data types and string type are supported. Sometimes users may not want to automatically
+infer the data types of the partitioning columns. For these use cases, the automatic type inference
+can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, which defaults to
+`true`. When type inference is disabled, string type will be used for the partitioning columns.
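+
+For example, a minimal sketch of disabling the inference at runtime (assuming an existing
+`SQLContext` named `sqlContext`):
+
+{% highlight scala %}
+// Treat all partitioning columns as strings instead of inferring their types.
+sqlContext.setConf("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
+{% endhighlight %}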
+
### Schema merging
@@ -1121,15 +1126,15 @@ import sqlContext.implicits._
// Create a simple DataFrame, stored into a partition directory
val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double")
-df1.saveAsParquetFile("data/test_table/key=1")
+df1.write.parquet("data/test_table/key=1")
// Create another DataFrame in a new partition directory,
// adding a new column and dropping an existing column
val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple")
-df2.saveAsParquetFile("data/test_table/key=2")
+df2.write.parquet("data/test_table/key=2")
// Read the partitioned table
-val df3 = sqlContext.parquetFile("data/test_table")
+val df3 = sqlContext.read.parquet("data/test_table")
df3.printSchema()
// The final schema consists of all 3 columns in the Parquet files together
@@ -1268,12 +1273,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read.json()` on either an RDD of String,
+or a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1284,8 +1287,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
val path = "examples/src/main/resources/people.json"
-// Create a DataFrame from the file(s) pointed to by path
-val people = sqlContext.jsonFile(path)
+val people = sqlContext.read.json(path)
// The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1303,19 +1305,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age
// an RDD[String] storing one JSON object per string.
val anotherPeopleRDD = sc.parallelize(
"""{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
-val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
+val anotherPeople = sqlContext.read.json(anotherPeopleRDD)
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext` :
+This conversion can be done using `SQLContext.read().json()` on either an RDD of String,
+or a JSON file.
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
-
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1325,9 +1325,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
-String path = "examples/src/main/resources/people.json";
-// Create a DataFrame from the file(s) pointed to by path
-DataFrame people = sqlContext.jsonFile(path);
+DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json");
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
@@ -1346,18 +1344,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN
List jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD anotherPeopleRDD = sc.parallelize(jsonData);
-DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD);
+DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD);
{% endhighlight %}
Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
-* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
+This conversion can be done using `SQLContext.read.json` on a JSON file.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1368,9 +1363,7 @@ sqlContext = SQLContext(sc)
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
-path = "examples/src/main/resources/people.json"
-# Create a DataFrame from the file(s) pointed to by path
-people = sqlContext.jsonFile(path)
+people = sqlContext.read.json("examples/src/main/resources/people.json")
# The inferred schema can be visualized using the printSchema() method.
people.printSchema()
@@ -1393,12 +1386,11 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
-This conversion can be done using one of two methods in a `SQLContext`:
-
-* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame using
+the `jsonFile` function, which loads data from a directory of JSON files where each line of the
+files is a JSON object.
-Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each
+Note that the file that is offered as _a json file_ is not a typical JSON file. Each
line must contain a separate, self-contained valid JSON object. As a consequence,
a regular multi-line JSON file will most often fail.
@@ -1502,7 +1494,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
-adds support for finding tables in the MetaStore and writing queries using HiveQL.
+adds support for finding tables in the MetaStore and writing queries using HiveQL.
{% highlight python %}
# sc is an existing SparkContext.
from pyspark.sql import HiveContext
@@ -1526,8 +1518,8 @@ adds support for finding tables in the MetaStore and writing queries using HiveQ
# sc is an existing SparkContext.
sqlContext <- sparkRHive.init(sc)
-hql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-hql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
# Queries can be expressed in HiveQL.
results = sqlContext.sql("FROM src SELECT key, value").collect()
@@ -1537,6 +1529,70 @@ results = sqlContext.sql("FROM src SELECT key, value").collect()
+### Interacting with Different Versions of Hive Metastore
+
+One of the most important pieces of Spark SQL's Hive support is interaction with the Hive metastore,
+which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below.
+
+Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET`
+and `DESCRIBE`, and the other dedicated to communicating with the Hive metastore. The former uses Hive
+jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the
+version specified by users. An isolated classloader is used here to avoid dependency conflicts.
+
+<table class="table">
+  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.version</code></td>
+    <td><code>0.13.1</code></td>
+    <td>
+      Version of the Hive metastore. Available
+      options are <code>0.12.0</code> and <code>0.13.1</code>. Support for more versions is coming in the future.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.jars</code></td>
+    <td><code>builtin</code></td>
+    <td>
+      Location of the jars that should be used to instantiate the HiveMetastoreClient. This
+      property can be one of three options:
+      <ol>
+        <li><code>builtin</code>: Use Hive 0.13.1, which is bundled with the Spark assembly jar when <code>-Phive</code> is
+        enabled. When this option is chosen, <code>spark.sql.hive.metastore.version</code> must be
+        either <code>0.13.1</code> or not defined.</li>
+        <li><code>maven</code>: Use Hive jars of specified version downloaded from Maven repositories.</li>
+        <li>A classpath in the standard format for both Hive and Hadoop.</li>
+      </ol>
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.sharedPrefixes</code></td>
+    <td></td>
+    <td>
+      A comma separated list of class prefixes that should be loaded using the classloader that is
+      shared between Spark SQL and a specific version of Hive. An example of classes that should
+      be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need
+      to be shared are those that interact with classes that are already shared. For example,
+      custom appenders that are used by log4j.
+    </td>
+  </tr>
+  <tr>
+    <td><code>spark.sql.hive.metastore.barrierPrefixes</code></td>
+    <td>(empty)</td>
+    <td>
+      A comma separated list of class prefixes that should explicitly be reloaded for each version
+      of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a
+      prefix that typically would be shared (i.e. org.apache.spark.*).
+    </td>
+  </tr>
+</table>
+
## JDBC To Other Databases
Spark SQL also includes a data source that can read data from other databases using JDBC. This
@@ -1570,7 +1626,7 @@ the Data Sources API. The following options are supported:
dbtable
- The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of
+ The JDBC table that should be read. Note that anything that is valid in a FROM clause of
a SQL query can be used. For example, instead of a full table you could also use a
subquery in parentheses.
@@ -1714,7 +1770,7 @@ that these options will be deprecated in future release as more optimizations ar
Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
statistics are only supported for Hive Metastore tables where the command
- `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run.
+ ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.
@@ -1733,11 +1789,20 @@ that these options will be deprecated in future release as more optimizations ar
Configures the number of partitions to use when shuffling data for joins or aggregations.
+  <tr>
+    <td><code>spark.sql.planner.externalSort</code></td>
+    <td>false</td>
+    <td>
+      When true, performs sorts spilling to disk as needed; otherwise sorts each partition in memory.
+    </td>
+  </tr>
# Distributed SQL Engine
-Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code.
+Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface.
+In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries,
+without the need to write any code.
## Running the Thrift JDBC/ODBC server
@@ -1816,6 +1881,25 @@ options.
## Upgrading from Spark SQL 1.3 to 1.4
+#### DataFrame data reader/writer interface
+
+Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`)
+and writing data out (`DataFrame.write`),
+and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`).
+
+See the API docs for `SQLContext.read` (
+ Scala,
+ Java,
+ Python
+) and `DataFrame.write` (
+ Scala,
+ Java,
+ Python
+) for more information.
+
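+A rough sketch of the change, reusing the example paths from earlier in this guide:
+
+{% highlight scala %}
+// Spark 1.3 style (now deprecated)
+val df = sqlContext.jsonFile("examples/src/main/resources/people.json")
+df.saveAsParquetFile("people.parquet")
+
+// Spark 1.4 style
+val df2 = sqlContext.read.json("examples/src/main/resources/people.json")
+df2.write.parquet("people.parquet")
+{% endhighlight %}
+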
+
+#### DataFrame.groupBy retains grouping columns
+
Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
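+
+As a sketch (the DataFrame `df` and its column names below are only illustrative):
+
+{% highlight scala %}
+import org.apache.spark.sql.functions._
+
+// In 1.4, the grouping column "department" is included in the result of agg().
+df.groupBy("department").agg(max("age"), sum("expense"))
+
+// Revert to the 1.3 behavior (grouping columns dropped from the result):
+sqlContext.setConf("spark.sql.retainGroupColumns", "false")
+{% endhighlight %}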
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 64714f0b799fc..998c8c994e4b4 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -7,7 +7,7 @@ title: Spark Streaming + Kafka Integration Guide
## Approach 1: Receiver-based Approach
This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming process the data.
-However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming. To ensure zero data loss, enable the Write Ahead Logs (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs.
+However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability)). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g. HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs.
Next, we discuss how to use this approach in your streaming application.
@@ -29,7 +29,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -39,7 +39,7 @@ Next, we discuss how to use this approach in your streaming application.
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]);
You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
@@ -105,7 +105,7 @@ Next, we discuss how to use this approach in your streaming application.
streamingContext, [map of Kafka parameters], [set of topics to consume])
See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
import org.apache.spark.streaming.kafka.*;
@@ -116,7 +116,7 @@ Next, we discuss how to use this approach in your streaming application.
[map of Kafka parameters], [set of topics to consume]);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
@@ -153,4 +153,4 @@ Next, we discuss how to use this approach in your streaming application.
Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API.
-3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and the launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation.
\ No newline at end of file
+3. **Deploying:** Similar to the first approach, you can package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR and then launch the application using `spark-submit`. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index bd863d48d53e3..42b33947873b0 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s
Receiving multiple data streams can therefore be achieved by creating multiple input DStreams
and configuring them to receive different partitions of the data stream from the source(s).
For example, a single Kafka input DStream receiving two topics of data can be split into two
-Kafka input streams, each receiving only one topic. This would run two receivers on two workers,
-thus allowing data to be received in parallel, and increasing overall throughput. These multiple
-DStream can be unioned together to create a single DStream. Then the transformations that was
-being applied on the single input DStream can applied on the unified stream. This is done as follows.
+Kafka input streams, each receiving only one topic. This would run two receivers,
+allowing data to be received in parallel, and increasing overall throughput. These multiple
+DStreams can be unioned together to create a single DStream. Then the transformations that were
+being applied on a single input DStream can be applied on the unified stream. This is done as follows.
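+
+A sketch of this pattern (topic names, ZooKeeper address and consumer group below are placeholders):
+
+{% highlight scala %}
+import org.apache.spark.streaming.kafka.KafkaUtils
+
+// One receiver-based input DStream per topic, then union them into a single DStream.
+val topics = Seq("topic1", "topic2")
+val kafkaStreams = topics.map { topic =>
+  KafkaUtils.createStream(streamingContext, "zk-host:2181", "consumer-group", Map(topic -> 1))
+}
+val unifiedStream = streamingContext.union(kafkaStreams)
+{% endhighlight %}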
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index c6d5a1f0d0a81..84629cb9a0ca0 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -19,8 +19,9 @@
# limitations under the License.
#
-from __future__ import with_statement, print_function
+from __future__ import division, print_function, with_statement
+import codecs
import hashlib
import itertools
import logging
@@ -47,6 +48,8 @@
else:
from urllib.request import urlopen, Request
from urllib.error import HTTPError
+ raw_input = input
+ xrange = range
SPARK_EC2_VERSION = "1.3.1"
SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -216,7 +219,8 @@ def parse_args():
"(default: %default).")
parser.add_option(
"--hadoop-major-version", default="1",
- help="Major version of Hadoop (default: %default)")
+ help="Major version of Hadoop. Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " +
+ "(Hadoop 2.4.0) (default: %default)")
parser.add_option(
"-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
@@ -268,7 +272,8 @@ def parse_args():
help="Launch fresh slaves, but use an existing stopped master if possible")
parser.add_option(
"--worker-instances", type="int", default=1,
- help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)")
+ help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. Not used if YARN " +
+ "is used as Hadoop major version (default: %default)")
parser.add_option(
"--master-opts", type="string", default="",
help="Extra options to give to master through SPARK_MASTER_OPTS variable " +
@@ -423,13 +428,14 @@ def get_spark_ami(opts):
b=opts.spark_ec2_git_branch)
ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type)
+ reader = codecs.getreader("ascii")
try:
- ami = urlopen(ami_path).read().strip()
- print("Spark AMI: " + ami)
+ ami = reader(urlopen(ami_path)).read().strip()
except:
print("Could not resolve AMI at: " + ami_path, file=stderr)
sys.exit(1)
+ print("Spark AMI: " + ami)
return ami
@@ -487,6 +493,8 @@ def launch_cluster(conn, opts, cluster_name):
master_group.authorize('udp', 2049, 2049, authorized_address)
master_group.authorize('tcp', 4242, 4242, authorized_address)
master_group.authorize('udp', 4242, 4242, authorized_address)
+ # RM in YARN mode uses 8088
+ master_group.authorize('tcp', 8088, 8088, authorized_address)
if opts.ganglia:
master_group.authorize('tcp', 5080, 5080, authorized_address)
if slave_group.rules == []: # Group was just now created
@@ -750,11 +758,15 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
'mapreduce', 'spark-standalone', 'tachyon']
if opts.hadoop_major_version == "1":
- modules = filter(lambda x: x != "mapreduce", modules)
+ modules = list(filter(lambda x: x != "mapreduce", modules))
if opts.ganglia:
modules.append('ganglia')
+ # Clear SPARK_WORKER_INSTANCES if running on YARN
+ if opts.hadoop_major_version == "yarn":
+ opts.worker_instances = ""
+
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format(
@@ -992,6 +1004,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes]
slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes]
+ worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else ""
template_vars = {
"master_list": '\n'.join(master_addresses),
"active_master": active_master,
@@ -1005,7 +1018,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
"spark_version": spark_v,
"tachyon_version": tachyon_v,
"hadoop_major_version": opts.hadoop_major_version,
- "spark_worker_instances": "%d" % opts.worker_instances,
+ "spark_worker_instances": worker_instances_str,
"spark_master_opts": opts.master_opts
}
@@ -1160,7 +1173,7 @@ def get_zones(conn, opts):
# Gets the number of items in a partition
def get_partition(total, num_partitions, current_partitions):
- num_slaves_this_zone = total / num_partitions
+ num_slaves_this_zone = total // num_partitions
if (total % num_partitions) - current_partitions > 0:
num_slaves_this_zone += 1
return num_slaves_this_zone
diff --git a/examples/pom.xml b/examples/pom.xml
index 5b04b4f8d6ca0..e6884b09dca94 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.4.0-SNAPSHOT</version>
+ <version>1.5.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
@@ -97,6 +97,11 @@
+ <dependency>
+   <groupId>org.apache.spark</groupId>
+   <artifactId>spark-streaming-kafka_${scala.binary.version}</artifactId>
+   <version>${project.version}</version>
+ </dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase-testing-util</artifactId>
@@ -392,45 +397,6 @@
- <profile>
-   <id>scala-2.10</id>
-   <activation>
-     <property>
-       <name>!scala-2.11</name>
-     </property>
-   </activation>
-   <dependencies>
-     <dependency>
-       <groupId>org.apache.spark</groupId>
-       <artifactId>spark-streaming-kafka_${scala.binary.version}</artifactId>
-       <version>${project.version}</version>
-     </dependency>
-   </dependencies>
-   <build>
-     <plugins>
-       <plugin>
-         <groupId>org.codehaus.mojo</groupId>
-         <artifactId>build-helper-maven-plugin</artifactId>
-         <executions>
-           <execution>
-             <id>add-scala-sources</id>
-             <phase>generate-sources</phase>
-             <goals>
-               <goal>add-source</goal>
-             </goals>
-             <configuration>
-               <sources>
-                 <source>src/main/scala</source>
-                 <source>scala-2.10/src/main/scala</source>
-                 <source>scala-2.10/src/main/java</source>
-               </sources>
-             </configuration>
-           </execution>
-         </executions>
-       </plugin>
-     </plugins>
-   </build>
- </profile>
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 29158d5c85651..dac649d1d5ae6 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -97,7 +97,7 @@ public static void main(String[] args) {
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
DataFrame results = model2.transform(test);
diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
similarity index 100%
rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
similarity index 100%
rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py
new file mode 100644
index 0000000000000..f0ca97c724940
--- /dev/null
+++ b/examples/src/main/python/ml/cross_validator.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating model selection using CrossValidator.
+This example also demonstrates how Pipelines are Estimators.
+Run with:
+
+ bin/spark-submit examples/src/main/python/ml/cross_validator.py
+"""
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="CrossValidatorExample")
+ sqlContext = SQLContext(sc)
+
+ # Prepare training documents, which are labeled.
+ LabeledDocument = Row("id", "text", "label")
+ training = sc.parallelize([(0, "a b c d e spark", 1.0),
+ (1, "b d", 0.0),
+ (2, "spark f g h", 1.0),
+ (3, "hadoop mapreduce", 0.0),
+ (4, "b spark who", 1.0),
+ (5, "g d a y", 0.0),
+ (6, "spark fly", 1.0),
+ (7, "was mapreduce", 0.0),
+ (8, "e spark program", 1.0),
+ (9, "a e c l", 0.0),
+ (10, "spark compile", 1.0),
+ (11, "hadoop software", 0.0)
+ ]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
+
+ # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ lr = LogisticRegression(maxIter=10)
+ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+
+ # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ # This will allow us to jointly choose parameters for all Pipeline stages.
+ # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ # We use a ParamGridBuilder to construct a grid of parameters to search over.
+ # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ paramGrid = ParamGridBuilder() \
+ .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
+ .addGrid(lr.regParam, [0.1, 0.01]) \
+ .build()
+
+ crossval = CrossValidator(estimator=pipeline,
+ estimatorParamMaps=paramGrid,
+ evaluator=BinaryClassificationEvaluator(),
+ numFolds=2) # use 3+ folds in practice
+
+ # Run cross-validation, and choose the best set of parameters.
+ cvModel = crossval.fit(training)
+
+ # Prepare test documents, which are unlabeled.
+ Document = Row("id", "text")
+ test = sc.parallelize([(4, "spark i j k"),
+ (5, "l m n"),
+ (6, "mapreduce spark"),
+ (7, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
+
+ # Make predictions on test documents. cvModel uses the best model found (lrModel).
+ prediction = cvModel.transform(test)
+ selected = prediction.select("id", "text", "probability", "prediction")
+ for row in selected.collect():
+ print(row)
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py
new file mode 100644
index 0000000000000..6446f0fe5eeab
--- /dev/null
+++ b/examples/src/main/python/ml/gradient_boosted_trees.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import GBTClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import GBTRegressor
+from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline.
+Note: GBTClassifier only supports binary classification currently
+Run with:
+ bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py
+"""
+
+
+def testClassification(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = BinaryClassificationMetrics(predictionAndLabels)
+ print("AUC %.3f" % metrics.areaUnderROC)
+
+
+def testRegression(train, test):
+ # Train a GradientBoostedTrees model.
+
+ rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel")
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: gradient_boosted_trees", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonGBTExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py
new file mode 100644
index 0000000000000..c7730e1bfacd9
--- /dev/null
+++ b/examples/src/main/python/ml/random_forest_example.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import RandomForestClassifier
+from pyspark.ml.feature import StringIndexer
+from pyspark.ml.regression import RandomForestRegressor
+from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics
+from pyspark.mllib.util import MLUtils
+from pyspark.sql import Row, SQLContext
+
+"""
+A simple example demonstrating a RandomForest Classification/Regression Pipeline.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/random_forest_example.py
+"""
+
+
+def testClassification(train, test):
+ # Train a RandomForest model.
+ # Setting featureSubsetStrategy="auto" lets the algorithm choose.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = MulticlassMetrics(predictionAndLabels)
+ print("weighted f-measure %.3f" % metrics.weightedFMeasure())
+ print("precision %s" % metrics.precision())
+ print("recall %s" % metrics.recall())
+
+
+def testRegression(train, test):
+ # Train a RandomForest model.
+ # Note: Use larger numTrees in practice.
+
+ rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4)
+
+ model = rf.fit(train)
+ predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \
+ .map(lambda x: (x.prediction, x.indexedLabel))
+
+ metrics = RegressionMetrics(predictionAndLabels)
+ print("rmse %.3f" % metrics.rootMeanSquaredError)
+ print("r2 %.3f" % metrics.r2)
+ print("mae %.3f" % metrics.meanAbsoluteError)
+
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: random_forest_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonRandomForestExample")
+ sqlContext = SQLContext(sc)
+
+ # Load and parse the data file into a dataframe.
+ df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+ # Map labels into an indexed column of labels in [0, numLabels)
+ stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
+ si_model = stringIndexer.fit(df)
+ td = si_model.transform(df)
+ [train, test] = td.randomSplit([0.7, 0.3])
+ testClassification(train, test)
+ testRegression(train, test)
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
new file mode 100644
index 0000000000000..a9f29dab2d602
--- /dev/null
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import pprint
+import sys
+
+from pyspark import SparkContext
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.linalg import DenseVector
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.sql import SQLContext
+
+"""
+A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/simple_params_example.py
+"""
+
+if __name__ == "__main__":
+ if len(sys.argv) > 1:
+ print("Usage: simple_params_example", file=sys.stderr)
+ exit(1)
+ sc = SparkContext(appName="PythonSimpleParamsExample")
+ sqlContext = SQLContext(sc)
+
+ # prepare training data.
+ # We create an RDD of LabeledPoints and convert them into a DataFrame.
+ # A LabeledPoint is an Object with two fields named label and features
+ # and Spark SQL identifies these fields and creates the schema appropriately.
+ training = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
+ LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
+ LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+
+ # Create a LogisticRegression instance with maxIter = 10.
+ # This instance is an Estimator.
+ lr = LogisticRegression(maxIter=10)
+ # Print out the parameters, documentation, and any default values.
+ print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ # We may also set parameters using setter methods.
+ lr.setRegParam(0.01)
+
+ # Learn a LogisticRegression model. This uses the parameters stored in lr.
+ model1 = lr.fit(training)
+
+ # Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ # we can view the parameters it used during fit().
+ # This prints the parameter (name: value) pairs, where names are unique IDs for this
+ # LogisticRegression instance.
+ print("Model 1 was fit using parameters:\n")
+ pprint.pprint(model1.extractParamMap())
+
+ # We may alternatively specify parameters using a parameter map.
+ # paramMap overrides all lr parameters set earlier.
+ paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+
+ # Now learn a new model using the new parameters.
+ model2 = lr.fit(training, paramMap)
+ print("Model 2 was fit using parameters:\n")
+ pprint.pprint(model2.extractParamMap())
+
+ # prepare test data.
+ test = sc.parallelize([
+ LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
+ LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
+ LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+
+ # Make predictions on test data using the Transformer.transform() method.
+ # LogisticRegressionModel.transform will only use the 'features' column.
+ # Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ # 'probability' column since we renamed the lr.probabilityCol parameter previously.
+ result = model2.transform(test) \
+ .select("features", "label", "myProbability", "prediction") \
+ .collect()
+
+ for row in result:
+ print("features=%s,label=%s -> prob=%s, prediction=%s"
+ % (row.features, row.label, row.myProbability, row.prediction))
+
+ sc.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index 96ddac761d698..e1fd85b082c08 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -51,7 +51,7 @@
parquet_rdd = sc.newAPIHadoopFile(
path,
- 'parquet.avro.AvroParquetInputFormat',
+ 'org.apache.parquet.avro.AvroParquetInputFormat',
'java.lang.Void',
'org.apache.avro.generic.IndexedRecord',
valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter')
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
index 11d5c92c5952d..023bb3ee2d108 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
@@ -104,8 +104,8 @@ object CassandraCQLTest {
val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(),
classOf[CqlPagingInputFormat],
- classOf[java.util.Map[String,ByteBuffer]],
- classOf[java.util.Map[String,ByteBuffer]])
+ classOf[java.util.Map[String, ByteBuffer]],
+ classOf[java.util.Map[String, ByteBuffer]])
println("Count: " + casRdd.count)
val productSaleRDD = casRdd.map {
@@ -118,7 +118,7 @@ object CassandraCQLTest {
case (productId, saleCount) => println(productId + ":" + saleCount)
}
- val casoutputCF = aggregatedRDD.map {
+ val casoutputCF = aggregatedRDD.map {
case (productId, saleCount) => {
val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId))
val outKey: java.util.Map[String, ByteBuffer] = outColFamKey
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
index a55e0dc8d36c2..c3fc74a116c0a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
@@ -39,7 +39,7 @@ object LocalLR {
def generateData: Array[DataPoint] = {
def generatePoint(i: Int): DataPoint = {
- val y = if(i % 2 == 0) -1 else 1
+ val y = if (i % 2 == 0) -1 else 1
val x = DenseVector.fill(D){rand.nextGaussian + y * R}
DataPoint(x, y)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index 32e02eab8b031..75c82117cbad2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._
/**
* Executes a roll up-style query against Apache logs.
- *
+ *
* Usage: LogQuery [logFile]
*/
object LogQuery {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index 6c0ac8013ce34..30c4261551837 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -117,7 +117,7 @@ object SparkALS {
var us = Array.fill(U)(randomVector(F))
// Iteratively update movies then users
- val Rc = sc.broadcast(R)
+ val Rc = sc.broadcast(R)
var msb = sc.broadcast(ms)
var usb = sc.broadcast(us)
for (iter <- 1 to ITERATIONS) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index 8c01a60844620..1e6b4fb0c7514 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -44,7 +44,7 @@ object SparkLR {
def generateData: Array[DataPoint] = {
def generatePoint(i: Int): DataPoint = {
- val y = if(i % 2 == 0) -1 else 1
+ val y = if (i % 2 == 0) -1 else 1
val x = DenseVector.fill(D){rand.nextGaussian + y * R}
DataPoint(x, y)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 8d092b6506d33..bd7894f184c4c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -51,7 +51,7 @@ object SparkPageRank {
showWarning()
val sparkConf = new SparkConf().setAppName("PageRank")
- val iters = if (args.length > 0) args(1).toInt else 10
+ val iters = if (args.length > 1) args(1).toInt else 10
val ctx = new SparkContext(sparkConf)
val lines = ctx.textFile(args(0), 1)
val links = lines.map{ s =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
deleted file mode 100644
index ab6e63deb3c95..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.bagel
-
-import org.apache.spark._
-import org.apache.spark.bagel._
-
-class PageRankUtils extends Serializable {
- def computeWithCombiner(numVertices: Long, epsilon: Double)(
- self: PRVertex, messageSum: Option[Double], superstep: Int
- ): (PRVertex, Array[PRMessage]) = {
- val newValue = messageSum match {
- case Some(msgSum) if msgSum != 0 =>
- 0.15 / numVertices + 0.85 * msgSum
- case _ => self.value
- }
-
- val terminate = superstep >= 10
-
- val outbox: Array[PRMessage] =
- if (!terminate) {
- self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size))
- } else {
- Array[PRMessage]()
- }
-
- (new PRVertex(newValue, self.outEdges, !terminate), outbox)
- }
-
- def computeNoCombiner(numVertices: Long, epsilon: Double)
- (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int)
- : (PRVertex, Array[PRMessage]) =
- computeWithCombiner(numVertices, epsilon)(self, messages match {
- case Some(msgs) => Some(msgs.map(_.value).sum)
- case None => None
- }, superstep)
-}
-
-class PRCombiner extends Combiner[PRMessage, Double] with Serializable {
- def createCombiner(msg: PRMessage): Double =
- msg.value
- def mergeMsg(combiner: Double, msg: PRMessage): Double =
- combiner + msg.value
- def mergeCombiners(a: Double, b: Double): Double =
- a + b
-}
-
-class PRVertex() extends Vertex with Serializable {
- var value: Double = _
- var outEdges: Array[String] = _
- var active: Boolean = _
-
- def this(value: Double, outEdges: Array[String], active: Boolean = true) {
- this()
- this.value = value
- this.outEdges = outEdges
- this.active = active
- }
-
- override def toString(): String = {
- "PRVertex(value=%f, outEdges.length=%d, active=%s)"
- .format(value, outEdges.length, active.toString)
- }
-}
-
-class PRMessage() extends Message[String] with Serializable {
- var targetId: String = _
- var value: Double = _
-
- def this(targetId: String, value: Double) {
- this()
- this.targetId = targetId
- this.value = value
- }
-}
-
-class CustomPartitioner(partitions: Int) extends Partitioner {
- def numPartitions: Int = partitions
-
- def getPartition(key: Any): Int = {
- val hash = key match {
- case k: Long => (k & 0x00000000FFFFFFFFL).toInt
- case _ => key.hashCode
- }
-
- val mod = key.hashCode % partitions
- if (mod < 0) mod + partitions else mod
- }
-
- override def equals(other: Any): Boolean = other match {
- case c: CustomPartitioner =>
- c.numPartitions == numPartitions
- case _ => false
- }
-
- override def hashCode: Int = numPartitions
-}
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
deleted file mode 100644
index 859abedf2a55e..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.bagel
-
-import org.apache.spark._
-import org.apache.spark.SparkContext._
-
-import org.apache.spark.bagel._
-
-import scala.xml.{XML,NodeSeq}
-
-/**
- * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles"
- * files from there, which contains one line per wiki article in a tab-separated format
- * (http://wiki.freebase.com/wiki/WEX/Documentation#articles).
- */
-object WikipediaPageRank {
- def main(args: Array[String]) {
- if (args.length < 4) {
- System.err.println(
- "Usage: WikipediaPageRank ")
- System.exit(-1)
- }
- val sparkConf = new SparkConf()
- sparkConf.setAppName("WikipediaPageRank")
- sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
-
- val inputFile = args(0)
- val threshold = args(1).toDouble
- val numPartitions = args(2).toInt
- val usePartitioner = args(3).toBoolean
-
- sparkConf.setAppName("WikipediaPageRank")
- val sc = new SparkContext(sparkConf)
-
- // Parse the Wikipedia page data into a graph
- val input = sc.textFile(inputFile)
-
- println("Counting vertices...")
- val numVertices = input.count()
- println("Done counting vertices.")
-
- println("Parsing input file...")
- var vertices = input.map(line => {
- val fields = line.split("\t")
- val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
- val links =
- if (body == "\\N") {
- NodeSeq.Empty
- } else {
- try {
- XML.loadString(body) \\ "link" \ "target"
- } catch {
- case e: org.xml.sax.SAXParseException =>
- System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body)
- NodeSeq.Empty
- }
- }
- val outEdges = links.map(link => new String(link.text)).toArray
- val id = new String(title)
- (id, new PRVertex(1.0 / numVertices, outEdges))
- })
- if (usePartitioner) {
- vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache()
- } else {
- vertices = vertices.cache()
- }
- println("Done parsing input file.")
-
- // Do the computation
- val epsilon = 0.01 / numVertices
- val messages = sc.parallelize(Array[(String, PRMessage)]())
- val utils = new PageRankUtils
- val result =
- Bagel.run(
- sc, vertices, messages, combiner = new PRCombiner(),
- numPartitions = numPartitions)(
- utils.computeWithCombiner(numVertices, epsilon))
-
- // Print the result
- System.err.println("Articles with PageRank >= " + threshold + ":")
- val top =
- (result
- .filter { case (id, vertex) => vertex.value >= threshold }
- .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) }
- .collect().mkString)
- println(top)
-
- sc.stop()
- }
-}
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala
deleted file mode 100644
index 576a3e371b993..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala
+++ /dev/null
@@ -1,232 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.bagel
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-import scala.xml.{XML, NodeSeq}
-
-import org.apache.spark._
-import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.RDD
-
-import scala.reflect.ClassTag
-
-object WikipediaPageRankStandalone {
- def main(args: Array[String]) {
- if (args.length < 4) {
- System.err.println("Usage: WikipediaPageRankStandalone " +
- "")
- System.exit(-1)
- }
- val sparkConf = new SparkConf()
- sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer")
-
- val inputFile = args(0)
- val threshold = args(1).toDouble
- val numIterations = args(2).toInt
- val usePartitioner = args(3).toBoolean
-
- sparkConf.setAppName("WikipediaPageRankStandalone")
-
- val sc = new SparkContext(sparkConf)
-
- val input = sc.textFile(inputFile)
- val partitioner = new HashPartitioner(sc.defaultParallelism)
- val links =
- if (usePartitioner) {
- input.map(parseArticle _).partitionBy(partitioner).cache()
- } else {
- input.map(parseArticle _).cache()
- }
- val n = links.count()
- val defaultRank = 1.0 / n
- val a = 0.15
-
- // Do the computation
- val startTime = System.currentTimeMillis
- val ranks =
- pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner,
- sc.defaultParallelism)
-
- // Print the result
- System.err.println("Articles with PageRank >= " + threshold + ":")
- val top =
- (ranks
- .filter { case (id, rank) => rank >= threshold }
- .map { case (id, rank) => "%s\t%s\n".format(id, rank) }
- .collect().mkString)
- println(top)
-
- val time = (System.currentTimeMillis - startTime) / 1000.0
- println("Completed %d iterations in %f seconds: %f seconds per iteration"
- .format(numIterations, time, time / numIterations))
- sc.stop()
- }
-
- def parseArticle(line: String): (String, Array[String]) = {
- val fields = line.split("\t")
- val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
- val id = new String(title)
- val links =
- if (body == "\\N") {
- NodeSeq.Empty
- } else {
- try {
- XML.loadString(body) \\ "link" \ "target"
- } catch {
- case e: org.xml.sax.SAXParseException =>
- System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body)
- NodeSeq.Empty
- }
- }
- val outEdges = links.map(link => new String(link.text)).toArray
- (id, outEdges)
- }
-
- def pageRank(
- links: RDD[(String, Array[String])],
- numIterations: Int,
- defaultRank: Double,
- a: Double,
- n: Long,
- partitioner: Partitioner,
- usePartitioner: Boolean,
- numPartitions: Int
- ): RDD[(String, Double)] = {
- var ranks = links.mapValues { edges => defaultRank }
- for (i <- 1 to numIterations) {
- val contribs = links.groupWith(ranks).flatMap {
- case (id, (linksWrapperIterable, rankWrapperIterable)) =>
- val linksWrapper = linksWrapperIterable.iterator
- val rankWrapper = rankWrapperIterable.iterator
- if (linksWrapper.hasNext) {
- val linksWrapperHead = linksWrapper.next
- if (rankWrapper.hasNext) {
- val rankWrapperHead = rankWrapper.next
- linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size))
- } else {
- linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size))
- }
- } else {
- Array[(String, Double)]()
- }
- }
- ranks = (contribs.combineByKey((x: Double) => x,
- (x: Double, y: Double) => x + y,
- (x: Double, y: Double) => x + y,
- partitioner)
- .mapValues(sum => a/n + (1-a)*sum))
- }
- ranks
- }
-}
-
-class WPRSerializer extends org.apache.spark.serializer.Serializer {
- def newInstance(): SerializerInstance = new WPRSerializerInstance()
-}
-
-class WPRSerializerInstance extends SerializerInstance {
- def serialize[T: ClassTag](t: T): ByteBuffer = {
- throw new UnsupportedOperationException()
- }
-
- def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
- throw new UnsupportedOperationException()
- }
-
- def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
- throw new UnsupportedOperationException()
- }
-
- def serializeStream(s: OutputStream): SerializationStream = {
- new WPRSerializationStream(s)
- }
-
- def deserializeStream(s: InputStream): DeserializationStream = {
- new WPRDeserializationStream(s)
- }
-}
-
-class WPRSerializationStream(os: OutputStream) extends SerializationStream {
- val dos = new DataOutputStream(os)
-
- def writeObject[T: ClassTag](t: T): SerializationStream = t match {
- case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
- case links: Array[String] => {
- dos.writeInt(0) // links
- dos.writeUTF(id)
- dos.writeInt(links.length)
- for (link <- links) {
- dos.writeUTF(link)
- }
- this
- }
- case rank: Double => {
- dos.writeInt(1) // rank
- dos.writeUTF(id)
- dos.writeDouble(rank)
- this
- }
- }
- case (id: String, rank: Double) => {
- dos.writeInt(2) // rank without wrapper
- dos.writeUTF(id)
- dos.writeDouble(rank)
- this
- }
- }
-
- def flush() { dos.flush() }
- def close() { dos.close() }
-}
-
-class WPRDeserializationStream(is: InputStream) extends DeserializationStream {
- val dis = new DataInputStream(is)
-
- def readObject[T: ClassTag](): T = {
- val typeId = dis.readInt()
- typeId match {
- case 0 => {
- val id = dis.readUTF()
- val numLinks = dis.readInt()
- val links = new Array[String](numLinks)
- for (i <- 0 until numLinks) {
- val link = dis.readUTF()
- links(i) = link
- }
- (id, ArrayBuffer(links)).asInstanceOf[T]
- }
- case 1 => {
- val id = dis.readUTF()
- val rank = dis.readDouble()
- (id, ArrayBuffer(rank)).asInstanceOf[T]
- }
- case 2 => {
- val id = dis.readUTF()
- val rank = dis.readDouble()
- (id, rank).asInstanceOf[T]
- }
- }
- }
-
- def close() { dis.close() }
-}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
new file mode 100644
index 0000000000000..b54466fd48bc5
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for linear regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LinearRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`, on which a
+ * model can be trained by
+ * {{{
+ * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \
+ * data/mllib/sample_linear_regression_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LinearRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LinearRegressionExample") {
+ head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LinearRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "regression", params.fracTest)
+
+ val lir = new LinearRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("label")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ // Train the model
+ val startTime = System.nanoTime()
+ val lirModel = lir.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Print the weights and intercept for linear regression.
+ println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label")
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label")
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
new file mode 100644
index 0000000000000..b12f833ce94c8
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.DataFrame
+
+/**
+ * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization.
+ * Run with
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample [options]
+ * }}}
+ * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt`, on which a
+ * model can be trained by
+ * {{{
+ * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \
+ * data/mllib/sample_libsvm_data.txt
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LogisticRegressionExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ regParam: Double = 0.0,
+ elasticNetParam: Double = 0.0,
+ maxIter: Int = 100,
+ fitIntercept: Boolean = true,
+ tol: Double = 1E-6,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LogisticRegressionExample") {
+ head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.")
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Double]("elasticNetParam")
+ .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
+ s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
+ s"L1 and L2, default: ${defaultParams.elasticNetParam}")
+ .action((x, c) => c.copy(elasticNetParam = x))
+ opt[Int]("maxIter")
+ .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Boolean]("fitIntercept")
+ .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}")
+ .action((x, c) => c.copy(fitIntercept = x))
+ opt[Double]("tol")
+ .text(s"the convergence tolerance of iterations, Smaller value will lead " +
+ s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
+ .action((x, c) => c.copy(tol = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"LogisticRegressionExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, "classification", params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol("indexedLabel")
+ stages += labelIndexer
+
+ val lor = new LogisticRegression()
+ .setFeaturesCol("features")
+ .setLabelCol("indexedLabel")
+ .setRegParam(params.regParam)
+ .setElasticNetParam(params.elasticNetParam)
+ .setMaxIter(params.maxIter)
+ .setTol(params.tol)
+
+ stages += lor
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
+ // Print the weights and intercept for logistic regression.
+ println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}")
+
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel")
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel")
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index b99d0a1246011..6927eb8f275cf 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -73,7 +73,7 @@ object OneVsRestExample {
.action((x, c) => c.copy(fracTest = x))
opt[String]("testInput")
.text("input path to test dataset. If given, option fracTest is ignored")
- .action((x,c) => c.copy(testInput = Some(x)))
+ .action((x, c) => c.copy(testInput = Some(x)))
opt[Int]("maxIter")
.text(s"maximum number of iterations for Logistic Regression." +
s" default: ${defaultParams.maxIter}")
@@ -88,10 +88,10 @@ object OneVsRestExample {
.action((x, c) => c.copy(fitIntercept = x))
opt[Double]("regParam")
.text(s"the regularization parameter for Logistic Regression.")
- .action((x,c) => c.copy(regParam = Some(x)))
+ .action((x, c) => c.copy(regParam = Some(x)))
opt[Double]("elasticNetParam")
.text(s"the ElasticNet mixing parameter for Logistic Regression.")
- .action((x,c) => c.copy(elasticNetParam = Some(x)))
+ .action((x, c) => c.copy(elasticNetParam = Some(x)))
checkConfig { params =>
if (params.fracTest < 0 || params.fracTest >= 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a991f50e338..a0561e2573fc9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -87,7 +87,7 @@ object SimpleParamsExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data using the Transformer.transform() method.
- // LogisticRegression.transform will only use the 'features' column.
+ // LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test.toDF())
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b0613632c9946..3381941673db8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,7 +22,6 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -354,7 +353,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
+ *
+ * This is just for demo purpose. In general, don't copy this code because it is NOT efficient
+ * due to the use of structural types, which leads to one reflection call per record.
*/
+ // scalastyle:off structural.type
private[mllib] def meanSquaredError(
model: { def predict(features: Vector): Double },
data: RDD[LabeledPoint]): Double = {
@@ -363,4 +366,5 @@ object DecisionTreeRunner {
err * err
}.mean()
}
+ // scalastyle:on structural.type
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
index df76b45e50810..f8c71ccabc43b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala
@@ -40,23 +40,23 @@ object DenseGaussianMixture {
private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
- val ctx = new SparkContext(conf)
-
+ val ctx = new SparkContext(conf)
+
val data = ctx.textFile(inputFile).map { line =>
Vectors.dense(line.trim.split(' ').map(_.toDouble))
}.cache()
-
+
val clusters = new GaussianMixture()
.setK(k)
.setConvergenceTol(convergenceTol)
.setMaxIterations(maxIterations)
.run(data)
-
+
for (i <- 0 until clusters.k) {
- println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
-
+
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
index a11890d6f2b1c..3ebb112fc069e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
@@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable {
return null
}
schema.getType match {
- case UNION => unpackUnion(obj, schema)
- case ARRAY => unpackArray(obj, schema)
- case FIXED => unpackFixed(obj, schema)
- case MAP => unpackMap(obj, schema)
- case BYTES => unpackBytes(obj)
- case RECORD => unpackRecord(obj)
- case STRING => obj.toString
- case ENUM => obj.toString
- case NULL => obj
+ case UNION => unpackUnion(obj, schema)
+ case ARRAY => unpackArray(obj, schema)
+ case FIXED => unpackFixed(obj, schema)
+ case MAP => unpackMap(obj, schema)
+ case BYTES => unpackBytes(obj)
+ case RECORD => unpackRecord(obj)
+ case STRING => obj.toString
+ case ENUM => obj.toString
+ case NULL => obj
case BOOLEAN => obj
- case DOUBLE => obj
- case FLOAT => obj
- case INT => obj
- case LONG => obj
- case other => throw new SparkException(
- s"Unknown Avro schema type ${other.getName}")
+ case DOUBLE => obj
+ case FLOAT => obj
+ case INT => obj
+ case LONG => obj
+ case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}")
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala
index 92867b44be138..016de4c63d1d2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala
@@ -104,10 +104,8 @@ extends Actor with ActorHelper {
object FeederActor {
def main(args: Array[String]) {
-    if(args.length < 2){
-      System.err.println(
-        "Usage: FeederActor <hostname> <port>\n"
-      )
+    if (args.length < 2) {
+      System.err.println("Usage: FeederActor <hostname> <port>\n")
System.exit(1)
}
val Seq(host, port) = args.toSeq
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
similarity index 97%
rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
index 11a8cf09533ce..fbe394de4a179 100644
--- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -51,7 +51,7 @@ object DirectKafkaWordCount {
// Create context with 2 second batch interval
val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
- val ssc = new StreamingContext(sparkConf, Seconds(2))
+ val ssc = new StreamingContext(sparkConf, Seconds(2))
// Create direct kafka stream with brokers and topics
val topicsSet = topics.split(",").toSet
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
similarity index 95%
rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
index f407367a54f6c..60416ee343544 100644
--- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
@@ -49,10 +49,10 @@ object KafkaWordCount {
val Array(zkQuorum, group, topics, numThreads) = args
val sparkConf = new SparkConf().setAppName("KafkaWordCount")
- val ssc = new StreamingContext(sparkConf, Seconds(2))
+ val ssc = new StreamingContext(sparkConf, Seconds(2))
ssc.checkpoint("checkpoint")
- val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap
+ val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap
val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1L))
@@ -96,7 +96,7 @@ object KafkaWordCountProducer {
producer.send(message)
}
- Thread.sleep(100)
+ Thread.sleep(1000)
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
index 85b9a54b40baf..813c8554f5193 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
@@ -40,7 +40,7 @@ object MQTTPublisher {
StreamingExamples.setStreamingLogLevels()
val Seq(brokerUrl, topic) = args.toSeq
-
+
var client: MqttClient = null
try {
@@ -49,7 +49,7 @@ object MQTTPublisher {
client.connect()
- val msgtopic = client.getTopic(topic)
+ val msgtopic = client.getTopic(topic)
val msgContent = "hello mqtt demo for spark streaming"
val message = new MqttMessage(msgContent.getBytes("utf-8"))
@@ -59,10 +59,10 @@ object MQTTPublisher {
println(s"Published data. topic: ${msgtopic.getName()}; Message: $message")
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- Thread.sleep(10)
+ Thread.sleep(10)
println("Queue is full, wait for to consume data from the message queue")
- }
- }
+ }
+ }
} catch {
case e: MqttException => println("Exception Caught: " + e)
} finally {
@@ -107,7 +107,7 @@ object MQTTWordCount {
val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2)
val words = lines.flatMap(x => x.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
-
+
wordCounts.print()
ssc.start()
ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
index 54d996b8ac990..889f052c70263 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
@@ -57,8 +57,7 @@ object PageViewGenerator {
404 -> .05)
val userZipCode = Map(94709 -> .5,
94117 -> .5)
- val userID = Map((1 to 100).map(_ -> .01):_*)
-
+ val userID = Map((1 to 100).map(_ -> .01) : _*)
def pickFromDistribution[T](inputMap : Map[T, Double]) : T = {
val rand = new Random().nextDouble()
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 1f3e619d97a24..7a7dccc3d0922 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -42,15 +42,46 @@
    <dependency>
      <groupId>org.apache.flume</groupId>
      <artifactId>flume-ng-sdk</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.apache.flume</groupId>
      <artifactId>flume-ng-core</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.thrift</groupId>
+          <artifactId>libthrift</artifactId>
+        </exclusion>
+      </exclusions>
    </dependency>
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <scope>test</scope>
+    </dependency>
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index fd01807fc3ac4..dc2a4ab138e18 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.flume.Channel
import org.apache.commons.lang3.RandomStringUtils
@@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils
private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel,
val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging {
val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads,
- new ThreadFactoryBuilder().setDaemon(true)
- .setNameFormat("Spark Sink Processor Thread - %d").build()))
+ new SparkSinkThreadFactory("Spark Sink Processor Thread - %d")))
// Protected by `sequenceNumberToProcessor`
private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]()
// This sink will not persist sequence numbers and reuses them if it gets restarted.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
similarity index 61%
rename from core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
rename to external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
index d75959f480756..845fc8debda75 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala
@@ -14,11 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.streaming.flume.sink
-package org.apache.spark.util.collection
+import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicLong
-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
- def hasNext: Boolean = iter.hasNext
+/**
+ * Thread factory that generates daemon threads with a specified name format.
+ */
+private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory {
+
+ private val threadId = new AtomicLong()
+
+ override def newThread(r: Runnable): Thread = {
+ val t = new Thread(r, nameFormat.format(threadId.incrementAndGet()))
+ t.setDaemon(true)
+ t
+ }
- def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
}
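A hypothetical usage sketch (not in this patch) for the new SparkSinkThreadFactory. It sits in the same package because the class is private[sink]; the submitted task prints a daemon-thread name such as "Demo Thread - 1".

// Assumed usage sketch only; the pool and task names are illustrative.
package org.apache.spark.streaming.flume.sink

import java.util.concurrent.{Executors, TimeUnit}

object SparkSinkThreadFactorySketch {
  def main(args: Array[String]): Unit = {
    // Each thread created by the pool is a daemon named from the format string.
    val pool = Executors.newFixedThreadPool(2, new SparkSinkThreadFactory("Demo Thread - %d"))
    pool.submit(new Runnable {
      override def run(): Unit = println(Thread.currentThread().getName)
    })
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}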
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
index ea45b14294df9..7ad43b1d7b0a0 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
@@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
eventBatch.setErrorMsg(msg)
} else {
// At this point, the events are available, so fill them into the event batch
- eventBatch = new EventBatch("",seqNum, events)
+ eventBatch = new EventBatch("", seqNum, events)
}
})
} catch {
diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
index 650b2fbe1c142..fa43629d49771 100644
--- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
+++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
@@ -24,16 +24,24 @@ import scala.collection.JavaConversions._
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
-import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.flume.Context
import org.apache.flume.channel.MemoryChannel
import org.apache.flume.event.EventBuilder
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
+
+// Due to MNG-1378, there is not a way to include test dependencies transitively.
+// We cannot include Spark core tests as a dependency here because it depends on
+// Spark core main, which has too many dependencies to require here manually.
+// For this reason, we continue to use FunSuite and ignore the scalastyle checks
+// that fail if this is detected.
+//scalastyle:off
import org.scalatest.FunSuite
class SparkSinkSuite extends FunSuite {
+//scalastyle:on
+
val eventsPerBatch = 1000
val channelCapacity = 5000
@@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite {
count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = {
(1 to count).map(_ => {
- lazy val channelFactoryExecutor =
- Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true).
- setNameFormat("Flume Receiver Channel Thread - %d").build())
+ lazy val channelFactoryExecutor = Executors.newCachedThreadPool(
+ new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d"))
lazy val channelFactory =
new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor)
val transceiver = new NettyTransceiver(address, channelFactory)
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 8df7edbdcad33..14f7daaf417e0 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming-flume-sink_${scala.binary.version}</artifactId>
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala
index dc629df4f4ac2..65c49c131518b 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala
@@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging {
out.write(body)
val numHeaders = headers.size()
out.writeInt(numHeaders)
- for ((k,v) <- headers) {
+ for ((k, v) <- headers) {
val keyBuff = Utils.serialize(k.toString)
out.writeInt(keyBuff.length)
out.write(keyBuff)
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 60e2994431b38..1e32a365a1eee 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -152,9 +152,9 @@ class FlumeReceiver(
val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(),
Executors.newCachedThreadPool())
val channelPipelineFactory = new CompressionChannelPipelineFactory()
-
+
new NettyServer(
- responder,
+ responder,
new InetSocketAddress(host, port),
channelFactory,
channelPipelineFactory,
@@ -188,12 +188,12 @@ class FlumeReceiver(
override def preferredLocation: Option[String] = Option(host)
- /** A Netty Pipeline factory that will decompress incoming data from
+ /** A Netty Pipeline factory that will decompress incoming data from
* and the Netty client and compress data going back to the client.
*
* The compression on the return is required because Flume requires
- * a successful response to indicate it can remove the event/batch
- * from the configured channel
+ * a successful response to indicate it can remove the event/batch
+ * from the configured channel
*/
private[streaming]
class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
index 92fa5b41be89e..583e7dca317ad 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
@@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver(
}
/**
- * A wrapper around the transceiver and the Avro IPC API.
+ * A wrapper around the transceiver and the Avro IPC API.
* @param transceiver The transceiver to use for communication with Flume
* @param client The client that the callbacks are received on.
*/
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 93afe50c2134f..d772b9ca9b570 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -31,16 +31,16 @@ import org.apache.flume.conf.Configurables
import org.apache.flume.event.EventBuilder
import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.streaming.flume.sink._
import org.apache.spark.util.{ManualClock, Utils}
-class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchCount = 5
val eventsPerBatch = 100
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 39e6754c81dbf..c926359987d89 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
import org.apache.spark.util.Utils
-class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
var ssc: StreamingContext = null
@@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L
val status = client.appendBatch(inputEvents.toList)
status should be (avro.Status.OK)
}
-
+
eventually(timeout(10 seconds), interval(100 milliseconds)) {
val outputEvents = outputBuffer.flatten.map { _.event }
outputEvents.foreach {
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 0b79f47647f6b..8059c443827ef 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 243ce6eaca658..ded863bd985e8 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.kafka</groupId>
      <artifactId>kafka_${scala.binary.version}</artifactId>
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index 6cf254a7b69cb..65d51d87f8486 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
r.flatMap { tm: TopicMetadata =>
tm.partitionsMetadata.map { pm: PartitionMetadata =>
TopicAndPartition(tm.topic, pm.partitionId)
- }
+ }
}
}
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index 6dc4e9517d5a4..b608b75952721 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging {
val props = new Properties()
props.put("metadata.broker.list", brokerAddress)
props.put("serializer.class", classOf[StringEncoder].getName)
+ // wait for all in-sync replicas to ack sends
+ props.put("request.required.acks", "-1")
props
}
@@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging {
tryAgain(1)
}
- /** Wait until the leader offset for the given topic/partition equals the specified offset */
- def waitUntilLeaderOffset(
- topic: String,
- partition: Int,
- offset: Long): Unit = {
- eventually(Time(10000), Time(100)) {
- val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress))
- val tp = TopicAndPartition(topic, partition)
- val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset
- assert(
- llo == offset,
- s"$topic $partition $offset not reached after timeout")
- }
- }
-
private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
case Some(partitionState) =>
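With the test producer now asking for acks from all in-sync replicas, sendMessages does not return until the data is durable on every ISR, which is why the explicit waitUntilLeaderOffset helper above could be removed. A hypothetical standalone sketch (not in this patch; the broker address and topic are placeholders) of a producer configured the same way with the old Kafka 0.8 Scala producer API already used here:

// Assumed sketch of an "ack from all in-sync replicas" producer configuration.
import java.util.Properties

import kafka.producer.{KeyedMessage, Producer, ProducerConfig}

object AckAllProducerSketch {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.put("metadata.broker.list", "localhost:9092") // placeholder broker address
    props.put("serializer.class", "kafka.serializer.StringEncoder")
    props.put("request.required.acks", "-1") // leader waits for the full ISR before acking
    val producer = new Producer[String, String](new ProducerConfig(props))
    // send() returns only after the broker has acknowledged per the acks setting.
    producer.send(new KeyedMessage[String, String]("test-topic", "hello"))
    producer.close()
  }
}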
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 8be2707528d93..0b8a391a2c569 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -315,7 +315,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -363,7 +363,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -427,7 +427,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
@@ -489,7 +489,7 @@ object KafkaUtils {
* Points to note:
* - No receivers: This stream does not use any receiver. It directly queries Kafka
* - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
- * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
* Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
* You can access the offsets used in each batch from the generated RDDs (see
* [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
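The doc comment repeated above notes that the direct stream tracks offsets itself and exposes them through HasOffsetRanges. A hypothetical sketch (not in this patch; the broker address and topic are placeholders) of reading those offsets from each generated RDD:

// Assumed sketch: inspecting the Kafka offsets consumed by a direct stream batch.
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

object OffsetRangesSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(new SparkConf().setAppName("OffsetRangesSketch"), Seconds(2))
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092") // placeholder broker
    val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Set("mytopic")) // placeholder topic

    stream.foreachRDD { rdd =>
      // RDDs produced by the direct stream carry their offset ranges; no Zookeeper lookup.
      val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      ranges.foreach { r =>
        println(s"${r.topic} partition ${r.partition}: ${r.fromOffset} -> ${r.untilOffset}")
      }
    }
    ssc.start()
    ssc.awaitTermination()
  }
}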
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index 5cf379635354f..a9dc6e50613ca 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException {
HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
- kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length);
- kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length);
-
OffsetRange[] offsetRanges = {
OffsetRange.create(topic1, 0, 0, 1),
OffsetRange.create(topic2, 0, 0, 1)
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index b6d314dfc7783..47bbfb605850a 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -28,10 +28,10 @@ import scala.language.postfixOps
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.Utils
class DirectKafkaStreamSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfter
with BeforeAndAfterAll
with Eventually
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index 7fb841b79cb65..d66830cbacdee 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka
import scala.util.Random
import kafka.common.TopicAndPartition
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
-class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
+import org.apache.spark.SparkFunSuite
+
+class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll {
private val topic = "kcsuitetopic" + Random.nextInt(10000)
private val topicAndPartition = TopicAndPartition(topic, 0)
private var kc: KafkaCluster = null
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 39c3fb448ff57..d5baf5fd89994 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -22,11 +22,11 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
-class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
+class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private var kafkaTestUtils: KafkaTestUtils = _
@@ -61,11 +61,9 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt}")
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size)
-
val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
- val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
+ val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
sc, kafkaParams, offsetRanges)
val received = rdd.map(_._2).collect.toSet
@@ -86,7 +84,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
// this is the "lots of messages" case
kafkaTestUtils.sendMessages(topic, sent)
val sentCount = sent.values.sum
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount)
// rdd defined from leaders after sending messages, should get the number sent
val rdd = getRdd(kc, Set(topic))
@@ -113,7 +110,6 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
val sentOnlyOne = Map("d" -> 1)
kafkaTestUtils.sendMessages(topic, sentOnlyOne)
- kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1)
assert(rdd2.isDefined)
assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index 24699dfc33adb..8ee2cc660f849 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -23,14 +23,14 @@ import scala.language.postfixOps
import scala.util.Random
import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
+class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
private var ssc: StreamingContext = _
private var kafkaTestUtils: KafkaTestUtils = _
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 38548dd73b82c..80e2df62de3fe 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -26,15 +26,15 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils
-class ReliableKafkaStreamSuite extends FunSuite
+class ReliableKafkaStreamSuite extends SparkFunSuite
with BeforeAndAfterAll with BeforeAndAfter with Eventually {
private val sparkConf = new SparkConf()
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 98f95a9a64fa0..0e41e5781784b 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.eclipse.paho</groupId>
      <artifactId>org.eclipse.paho.client.mqttv3</artifactId>
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
index 40f5f18547236..7c2f18cb35bda 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
@@ -17,22 +17,10 @@
package org.apache.spark.streaming.mqtt
-import java.io.IOException
-import java.util.concurrent.Executors
-import java.util.Properties
-
-import scala.collection.JavaConversions._
-import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.reflect.ClassTag
-
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
import org.eclipse.paho.client.mqttv3.MqttCallback
import org.eclipse.paho.client.mqttv3.MqttClient
-import org.eclipse.paho.client.mqttv3.MqttClientPersistence
-import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
-import org.eclipse.paho.client.mqttv3.MqttTopic
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
import org.apache.spark.storage.StorageLevel
@@ -87,7 +75,7 @@ class MQTTReceiver(
// Handles Mqtt message
override def messageArrived(topic: String, message: MqttMessage) {
- store(new String(message.getPayload(),"utf-8"))
+ store(new String(message.getPayload(), "utf-8"))
}
override def deliveryComplete(token: IMqttDeliveryToken) {
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index a19a72c58a705..c4bf5aa7869bb 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
@@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
-class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
private val master = "local[2]"
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 8b6a8959ac4cf..178ae8de13b57 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.twitter4j</groupId>
      <artifactId>twitter4j-stream</artifactId>
diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
index 9ee57d7581d85..d9acb568879fe 100644
--- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
+++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
@@ -18,16 +18,16 @@
package org.apache.spark.streaming.twitter
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {
val batchDuration = Seconds(1)
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index a50d378b34335..37bfd10d43663 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -41,6 +41,13 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>${akka.group}</groupId>
      <artifactId>akka-zeromq_${scala.binary.version}</artifactId>
diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
index a7566e733d891..35d2e62c68480 100644
--- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
+++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq
import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class ZeroMQStreamSuite extends FunSuite {
+class ZeroMQStreamSuite extends SparkFunSuite {
val batchDuration = Seconds(1)
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 4351a8a12fe21..f138251748c9e 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 25847a1b33d9c..c6f60bc907438 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -40,6 +40,13 @@
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index df77f4be9db1d..be8b62d3cc6ba 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -119,7 +119,7 @@ object KinesisWordCountASL extends Logging {
val batchInterval = Milliseconds(2000)
// Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information
- // on sequence number of records that have been received. Same as batchInterval for this
+ // on sequence number of records that have been received. Same as batchInterval for this
// example.
val kinesisCheckpointInterval = batchInterval
@@ -145,7 +145,7 @@ object KinesisWordCountASL extends Logging {
// Map each word to a (word, 1) tuple so we can reduce by key to count the words
val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _)
-
+
// Print the first 10 wordCounts
wordCounts.print()
@@ -208,16 +208,16 @@ object KinesisWordProducerASL {
recordsPerSecond: Int,
wordsPerRecord: Int): Seq[(String, Int)] = {
- val randomWords = List("spark","you","are","my","father")
+ val randomWords = List("spark", "you", "are", "my", "father")
val totals = scala.collection.mutable.Map[String, Int]()
-
+
// Create the low-level Kinesis Client from the AWS Java SDK.
val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain())
kinesisClient.setEndpoint(endpoint)
println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" +
s" $recordsPerSecond records per second and $wordsPerRecord words per record")
-
+
// Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord
for (i <- 1 to 10) {
// Generate recordsPerSec records to put onto the stream
@@ -255,8 +255,8 @@ object KinesisWordProducerASL {
}
}
-/**
- * Utility functions for Spark Streaming examples.
+/**
+ * Utility functions for Spark Streaming examples.
* This has been lifted from the examples/ project to remove the circular dependency.
*/
private[streaming] object StreamingExamples extends Logging {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
index 1c9b0c218ae18..83a4537559512 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
@@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock}
/**
* This is a helper class for managing checkpoint clocks.
*
- * @param checkpointInterval
+ * @param checkpointInterval
* @param currentClock. Default to current SystemClock if none is passed in (mocking purposes)
*/
private[kinesis] class KinesisCheckpointState(
- checkpointInterval: Duration,
+ checkpointInterval: Duration,
currentClock: Clock = new SystemClock())
extends Logging {
-
+
/* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */
val checkpointClock = new ManualClock()
checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds)
/**
- * Check if it's time to checkpoint based on the current time and the derived time
+ * Check if it's time to checkpoint based on the current time and the derived time
* for the next checkpoint
*
* @return true if it's time to checkpoint
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 7dd8bfdc2a6db..1a8a4cecc1141 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -44,12 +44,12 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* https://github.com/awslabs/amazon-kinesis-client
* This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here:
* http://spark.apache.org/docs/latest/streaming-custom-receivers.html
- * Instances of this class will get shipped to the Spark Streaming Workers to run within a
+ * Instances of this class will get shipped to the Spark Streaming Workers to run within a
* Spark Executor.
*
* @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams
* by the Kinesis Client Library. If you change the App name or Stream name,
- * the KCL will throw errors. This usually requires deleting the backing
+ * the KCL will throw errors. This usually requires deleting the backing
* DynamoDB table with the same name this Kinesis application.
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
@@ -87,7 +87,7 @@ private[kinesis] class KinesisReceiver(
*/
/**
- * workerId is used by the KCL should be based on the ip address of the actual Spark Worker
+ * workerId is used by the KCL should be based on the ip address of the actual Spark Worker
* where this code runs (not the driver's IP address.)
*/
private var workerId: String = null
@@ -121,7 +121,7 @@ private[kinesis] class KinesisReceiver(
/*
* RecordProcessorFactory creates impls of IRecordProcessor.
- * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
+ * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the
* IRecordProcessor.processRecords() method.
* We're using our custom KinesisRecordProcessor in this case.
*/
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index f65e743c4e2a3..fe9e3a0c793e2 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -35,9 +35,9 @@ import com.amazonaws.services.kinesis.model.Record
/**
* Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor.
* This implementation operates on the Array[Byte] from the KinesisReceiver.
- * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
- * shard in the Kinesis stream upon startup. This is normally done in separate threads,
- * but the KCLs within the KinesisReceivers will balance themselves out if you create
+ * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
+ * shard in the Kinesis stream upon startup. This is normally done in separate threads,
+ * but the KCLs within the KinesisReceivers will balance themselves out if you create
* multiple Receivers.
*
* @param receiver Kinesis receiver
@@ -69,14 +69,14 @@ private[kinesis] class KinesisRecordProcessor(
* and Spark Streaming's Receiver.store().
*
* @param batch list of records from the Kinesis stream shard
- * @param checkpointer used to update Kinesis when this batch has been processed/stored
+ * @param checkpointer used to update Kinesis when this batch has been processed/stored
* in the DStream
*/
override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
if (!receiver.isStopped()) {
try {
/*
- * Notes:
+ * Notes:
* 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming
* Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the
* internally-configured Spark serializer (kryo, etc).
@@ -84,19 +84,19 @@ private[kinesis] class KinesisRecordProcessor(
* ourselves from Spark's internal serialization strategy.
* 3) For performance, the BlockGenerator is asynchronously queuing elements within its
* memory before creating blocks. This prevents the small block scenario, but requires
- * that you register callbacks to know when a block has been generated and stored
+ * that you register callbacks to know when a block has been generated and stored
* (WAL is sufficient for storage) before can checkpoint back to the source.
*/
batch.foreach(record => receiver.store(record.getData().array()))
-
+
logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
/*
- * Checkpoint the sequence number of the last record successfully processed/stored
+ * Checkpoint the sequence number of the last record successfully processed/stored
* in the batch.
* In this implementation, we're checkpointing after the given checkpointIntervalMillis.
- * Note that this logic requires that processRecords() be called AND that it's time to
- * checkpoint. I point this out because there is no background thread running the
+ * Note that this logic requires that processRecords() be called AND that it's time to
+ * checkpoint. I point this out because there is no background thread running the
* checkpointer. Checkpointing is tested and trigger only when a new batch comes in.
* If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below).
* However, if the worker dies unexpectedly, a checkpoint may not happen.
@@ -130,16 +130,16 @@ private[kinesis] class KinesisRecordProcessor(
}
} else {
/* RecordProcessor has been stopped. */
- logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
+ logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
s" and shardId $shardId. No more records will be processed.")
}
}
/**
* Kinesis Client Library is shutting down this Worker for 1 of 2 reasons:
- * 1) the stream is resharding by splitting or merging adjacent shards
+ * 1) the stream is resharding by splitting or merging adjacent shards
* (ShutdownReason.TERMINATE)
- * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
+ * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason
* (ShutdownReason.ZOMBIE)
*
* @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE
@@ -153,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor(
* Checkpoint to indicate that all records from the shard have been drained and processed.
* It's now OK to read from the new shards that resulted from a resharding event.
*/
- case ShutdownReason.TERMINATE =>
+ case ShutdownReason.TERMINATE =>
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
/*
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index 2531aebe7813c..e5acab50181e1 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -55,7 +55,7 @@ object KinesisUtils {
*/
def createStream(
ssc: StreamingContext,
- kinesisAppName: String,
+ kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
@@ -102,7 +102,7 @@ object KinesisUtils {
*/
def createStream(
ssc: StreamingContext,
- kinesisAppName: String,
+ kinesisAppName: String,
streamName: String,
endpointUrl: String,
regionName: String,
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index e14bbae4a9b6e..478d0019a25f0 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/graphx/pom.xml b/graphx/pom.xml
index d38a3aa8256b7..853dea9a7795e 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
@@ -40,6 +40,13 @@
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
index 058c8c8aa1b24..ce1054ed92ba1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
@@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends Serializable {
* out becomes in and both and either remain the same.
*/
def reverse: EdgeDirection = this match {
- case EdgeDirection.In => EdgeDirection.Out
- case EdgeDirection.Out => EdgeDirection.In
+ case EdgeDirection.In => EdgeDirection.Out
+ case EdgeDirection.Out => EdgeDirection.In
case EdgeDirection.Either => EdgeDirection.Either
case EdgeDirection.Both => EdgeDirection.Both
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index cc70b396a8dd4..4611a3ace219b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -41,14 +41,16 @@ abstract class EdgeRDD[ED](
@transient sc: SparkContext,
@transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) {
+ // scalastyle:off structural.type
private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD }
+ // scalastyle:on structural.type
override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context)
if (p.hasNext) {
- p.next._2.iterator.map(_.copy())
+ p.next()._2.iterator.map(_.copy())
} else {
Iterator.empty
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
index c8790cac3d8a0..65f82429d2029 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
@@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] {
/**
* Set the edge properties of this triplet.
*/
- protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = {
+ protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = {
srcId = other.srcId
dstId = other.dstId
attr = other.attr
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 36dc7b0f86c89..db73a8abc5733 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* satisfy the predicates
*/
def subgraph(
- epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
+ epred: EdgeTriplet[VD, ED] => Boolean = (x => true),
vpred: (VertexId, VD) => Boolean = ((v, d) => true))
: Graph[VD, ED]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 7edd627b20918..9451ff1e5c0e2 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
val nbrs = edgeDirection match {
case EdgeDirection.Either =>
- graph.aggregateMessages[Array[(VertexId,VD)]](
+ graph.aggregateMessages[Array[(VertexId, VD)]](
ctx => {
ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
},
(a, b) => a ++ b, TripletFields.All)
case EdgeDirection.In =>
- graph.aggregateMessages[Array[(VertexId,VD)]](
+ graph.aggregateMessages[Array[(VertexId, VD)]](
ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
(a, b) => a ++ b, TripletFields.Src)
case EdgeDirection.Out =>
- graph.aggregateMessages[Array[(VertexId,VD)]](
+ graph.aggregateMessages[Array[(VertexId, VD)]](
ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
(a, b) => a ++ b, TripletFields.Dst)
case EdgeDirection.Both =>
@@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def filter[VD2: ClassTag, ED2: ClassTag](
preprocess: Graph[VD, ED] => Graph[VD2, ED2],
epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true,
- vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = {
+ vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = {
graph.mask(preprocess(graph).subgraph(epred, vpred))
}
@@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either)(
vprog: (VertexId, VD, A) => VD,
- sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)],
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
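
The `pregel` wrapper whose signature is cleaned up above is typically driven like the following single-source shortest-paths sketch (illustrative only; `sc` is an assumed existing SparkContext):

```scala
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators

// Build a random graph with Double edge weights.
val graph: Graph[Long, Double] =
  GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
val sourceId: VertexId = 42L
// All vertices start at infinite distance except the source.
val initialGraph = graph.mapVertices((id, _) =>
  if (id == sourceId) 0.0 else Double.PositiveInfinity)

val sssp = initialGraph.pregel(Double.PositiveInfinity)(
  (id, dist, newDist) => math.min(dist, newDist), // vprog: keep the shorter distance
  triplet => {                                    // sendMsg: relax improving edges
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },
  (a, b) => math.min(a, b)                        // mergeMsg
)
println(sssp.vertices.take(10).mkString("\n"))
```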
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 01b013ff716fc..cfcf7244eaed5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -147,10 +147,10 @@ object Pregel extends Logging {
logInfo("Pregel finished iteration " + i)
// Unpersist the RDDs hidden by newly-materialized RDDs
- oldMessages.unpersist(blocking=false)
- newVerts.unpersist(blocking=false)
- prevG.unpersistVertices(blocking=false)
- prevG.edges.unpersist(blocking=false)
+ oldMessages.unpersist(blocking = false)
+ newVerts.unpersist(blocking = false)
+ prevG.unpersistVertices(blocking = false)
+ prevG.edges.unpersist(blocking = false)
// count the iteration
i += 1
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index c561570809253..ab021a252eb8a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -156,8 +156,8 @@ class EdgePartition[
val size = data.size
var i = 0
while (i < size) {
- edge.srcId = srcIds(i)
- edge.dstId = dstIds(i)
+ edge.srcId = srcIds(i)
+ edge.dstId = dstIds(i)
edge.attr = data(i)
newData(i) = f(edge)
i += 1
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index bc974b2f04e70..8c0a461e99fa4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -116,7 +116,7 @@ object PageRank extends Logging {
val personalized = srcId isDefined
val src: VertexId = srcId.getOrElse(-1L)
- def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 }
+ def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 }
var iteration = 0
var prevRankGraph: Graph[Double, Double] = null
@@ -133,13 +133,13 @@ object PageRank extends Logging {
// edge partitions.
prevRankGraph = rankGraph
val rPrb = if (personalized) {
- (src: VertexId ,id: VertexId) => resetProb * delta(src,id)
+ (src: VertexId, id: VertexId) => resetProb * delta(src, id)
} else {
(src: VertexId, id: VertexId) => resetProb
}
rankGraph = rankGraph.joinVertices(rankUpdates) {
- (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum
+ (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum
}.cache()
rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices
@@ -243,7 +243,7 @@ object PageRank extends Logging {
// Execute a dynamic version of Pregel.
val vp = if (personalized) {
- (id: VertexId, attr: (Double, Double),msgSum: Double) =>
+ (id: VertexId, attr: (Double, Double), msgSum: Double) =>
personalizedVertexProgram(id, attr, msgSum)
} else {
(id: VertexId, attr: (Double, Double), msgSum: Double) =>
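
The `delta` and `rPrb` changes above touch both the standard and personalized variants; driver-side usage is roughly as follows (a sketch; `sc` and the follower edge-list path are assumptions):

```scala
import org.apache.spark.graphx.GraphLoader

// Illustrative only; tolerance values are arbitrary.
val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
val ranks = graph.pageRank(tol = 0.0001).vertices             // standard PageRank
val pRanks = graph.personalizedPageRank(1L, 0.0001).vertices  // biased towards vertex 1
ranks.take(5).foreach(println)
```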
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index 3b0e1628d86b5..9cb24ed080e1c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -210,7 +210,7 @@ object SVDPlusPlus {
/**
* Forces materialization of a Graph by count()ing its RDDs.
*/
- private def materialize(g: Graph[_,_]): Unit = {
+ private def materialize(g: Graph[_, _]): Unit = {
g.vertices.count()
g.edges.count()
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index daf162085e3e4..a5d598053f9ca 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -38,7 +38,7 @@ import org.apache.spark.graphx._
*/
object TriangleCount {
- def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = {
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = {
// Remove redundant edges
val g = graph.groupEdges((a, b) => a).cache()
@@ -49,7 +49,7 @@ object TriangleCount {
var i = 0
while (i < nbrs.size) {
// prevent self cycle
- if(nbrs(i) != vid) {
+ if (nbrs(i) != vid) {
set.add(nbrs(i))
}
i += 1
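
As the comments above note, `TriangleCount.run` deduplicates edges and builds neighbor sets, so the input graph should be loaded with canonical edge orientation and partitioned; a hedged usage sketch (edge-list path is an assumption):

```scala
import org.apache.spark.graphx.{GraphLoader, PartitionStrategy}

// Illustrative only.
val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt", canonicalOrientation = true)
  .partitionBy(PartitionStrategy.RandomVertexCut)
val triCounts = graph.triangleCount().vertices
triCounts.take(5).foreach(println)
```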
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
index 2d6a825b61726..9591c4e9b8f4e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -243,14 +243,15 @@ object GraphGenerators {
* @return A graph containing vertices with the row and column ids
* as their attributes and edge values as 1.0.
*/
- def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = {
+ def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = {
// Convert row column address into vertex ids (row major order)
def sub2ind(r: Int, c: Int): VertexId = r * cols + c
- val vertices: RDD[(VertexId, (Int,Int))] =
- sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) )
+ val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r =>
+ (0 until cols).map( c => (sub2ind(r, c), (r, c)) )
+ }
val edges: RDD[Edge[Double]] =
- vertices.flatMap{ case (vid, (r,c)) =>
+ vertices.flatMap{ case (vid, (r, c)) =>
(if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++
(if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty })
}.map{ case (src, dst) => Edge(src, dst, 1.0) }
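
The reflowed `gridGraph` builder keeps its row-major vertex ids; as a quick sanity sketch of the counts it produces (illustrative only):

```scala
import org.apache.spark.graphx.util.GraphGenerators

// A 3x2 grid has 6 vertices and rows*(cols-1) + (rows-1)*cols = 3*1 + 2*2 = 7 directed edges.
val grid = GraphGenerators.gridGraph(sc, rows = 3, cols = 2)
require(grid.vertices.count() == 6)
require(grid.edges.count() == 7)
```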
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
index eb1dbe52c2fda..f1ecc9e2219d1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
-class EdgeRDDSuite extends FunSuite with LocalSparkContext {
+class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext {
test("cache, getStorageLevel") {
// test to see if getStorageLevel returns correct value after caching
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
index 5a2c73b414279..094a63472eaab 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala
@@ -17,21 +17,21 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class EdgeSuite extends FunSuite {
+class EdgeSuite extends SparkFunSuite {
test ("compare") {
// descending order
val testEdges: Array[Edge[Int]] = Array(
- Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
- Edge(0x2345L, 0x1234L, 1),
- Edge(0x1234L, 0x5678L, 1),
- Edge(0x1234L, 0x2345L, 1),
+ Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1),
+ Edge(0x2345L, 0x1234L, 1),
+ Edge(0x1234L, 0x5678L, 1),
+ Edge(0x1234L, 0x2345L, 1),
Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1)
)
// to ascending order
val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int])
-
+
for (i <- 0 until testEdges.length) {
assert(sortedEdges(i) == testEdges(testEdges.length - i - 1))
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index 9bc8007ce49cd..57a8b95dd12e9 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.impl.EdgePartition
import org.apache.spark.rdd._
-import org.scalatest.FunSuite
-class GraphOpsSuite extends FunSuite with LocalSparkContext {
+class GraphOpsSuite extends SparkFunSuite with LocalSparkContext {
test("joinVertices") {
withSpark { sc =>
@@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
test ("filter") {
withSpark { sc =>
val n = 5
- val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x)))
+ val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x)))
val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
val filteredGraph = graph.filter(
@@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
val degrees: VertexRDD[Int] = graph.outDegrees
graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
},
- vpred = (vid: VertexId, deg:Int) => deg > 0
+ vpred = (vid: VertexId, deg: Int) => deg > 0
).cache()
val v = filteredGraph.vertices.collect().toSet
- assert(v === Set((0,0)))
+ assert(v === Set((0, 0)))
// the map is necessary because of object-reuse in the edge iterator
val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index a570e4ed75fc3..1f5e27d5508b8 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class GraphSuite extends FunSuite with LocalSparkContext {
+class GraphSuite extends SparkFunSuite with LocalSparkContext {
def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = {
Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v")
@@ -248,7 +246,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
test("mask") {
withSpark { sc =>
val n = 5
- val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x)))
+ val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x)))
val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
@@ -260,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext {
val projectedGraph = graph.mask(subgraph)
val v = projectedGraph.vertices.collect().toSet
- assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5)))
+ assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5)))
// the map is necessary because of object-reuse in the edge iterator
val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
- assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5)))
+ assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5)))
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
index 490b94429ea1f..8afa2d403b53f 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.rdd._
-class PregelSuite extends FunSuite with LocalSparkContext {
+class PregelSuite extends SparkFunSuite with LocalSparkContext {
test("1 iteration") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
index d0a7198d691d7..f1aa685a79c98 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.graphx
-import org.scalatest.FunSuite
-
-import org.apache.spark.{HashPartitioner, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-class VertexRDDSuite extends FunSuite with LocalSparkContext {
+class VertexRDDSuite extends SparkFunSuite with LocalSparkContext {
private def vertices(sc: SparkContext, n: Int) = {
VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 515f3a9cd02eb..7435647c6d9ee 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl
import scala.reflect.ClassTag
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class EdgePartitionSuite extends FunSuite {
+class EdgePartitionSuite extends SparkFunSuite {
def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = {
val builder = new EdgePartitionBuilder[A, Int]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index fe8304c1cdc32..1203f8959f506 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.graphx.impl
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.graphx._
-class VertexPartitionSuite extends FunSuite {
+class VertexPartitionSuite extends SparkFunSuite {
test("isDefined, filter") {
val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
index 4cc30a96408f8..c965a6eb8df13 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Grid Connected Components") {
withSpark { sc =>
@@ -52,13 +50,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val chain1 = (0 until 9).map(x => (x, x + 1))
val chain2 = (10 until 20).map(x => (x, x + 1))
- val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0)
val ccGraph = twoChains.connectedComponents()
val vertices = ccGraph.vertices.collect()
for ( (id, cc) <- vertices ) {
- if(id < 10) { assert(cc === 0) }
- else { assert(cc === 10) }
+ if (id < 10) {
+ assert(cc === 0)
+ } else {
+ assert(cc === 10)
+ }
}
val ccMap = vertices.toMap
for (id <- 0 until 20) {
@@ -75,7 +76,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
withSpark { sc =>
val chain1 = (0 until 9).map(x => (x, x + 1))
val chain2 = (10 until 20).map(x => (x, x + 1))
- val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse
val ccGraph = twoChains.connectedComponents()
val vertices = ccGraph.vertices.collect()
@@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
(4L, ("peter", "student"))))
// Create an RDD for edges
val relationships: RDD[Edge[String]] =
- sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
+ sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"),
- Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague")))
+ Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague")))
// Edges are:
// 2 ---> 5 ---> 3
// | \
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
index 61fd0c4605568..808877f0590f8 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class LabelPropagationSuite extends FunSuite with LocalSparkContext {
+class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext {
test("Label Propagation") {
withSpark { sc =>
// Construct a graph with two cliques connected by a single edge
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index 3f3c9dfd7b3dd..45f1e3011035e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
@@ -31,14 +30,14 @@ object GridPageRank {
def sub2ind(r: Int, c: Int): Int = r * nCols + c
// Make the grid graph
for (r <- 0 until nRows; c <- 0 until nCols) {
- val ind = sub2ind(r,c)
+ val ind = sub2ind(r, c)
if (r + 1 < nRows) {
outDegree(ind) += 1
- inNbrs(sub2ind(r + 1,c)) += ind
+ inNbrs(sub2ind(r + 1, c)) += ind
}
if (c + 1 < nCols) {
outDegree(ind) += 1
- inNbrs(sub2ind(r,c + 1)) += ind
+ inNbrs(sub2ind(r, c + 1)) += ind
}
}
// compute the pagerank
@@ -57,7 +56,7 @@ object GridPageRank {
}
-class PageRankSuite extends FunSuite with LocalSparkContext {
+class PageRankSuite extends SparkFunSuite with LocalSparkContext {
def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
@@ -99,8 +98,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
val resetProb = 0.15
val errorTol = 1.0e-5
- val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices
- val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb)
+ val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices
+ val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb)
.vertices.cache()
// Static PageRank should only take 2 iterations to converge
@@ -117,7 +116,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
}
assert(staticErrors.sum === 0)
- val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache()
+ val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache()
assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
}
} // end of test Star PageRank
@@ -162,7 +161,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext {
test("Chain PersonalizedPageRank") {
withSpark { sc =>
val chain1 = (0 until 9).map(x => (x, x + 1) )
- val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) }
+ val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) }
val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
val resetProb = 0.15
val tol = 0.0001
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index 7bd6b7f3c4ab2..2991438f5e57e 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
-class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
+class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext {
test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
index f2c38e79c452c..d7eaa70ce6407 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class ShortestPathsSuite extends FunSuite with LocalSparkContext {
+class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext {
test("Shortest Path Computations") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
index 1f658c371ffcf..d6b03208180db 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -17,16 +17,14 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._
-class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext {
test("Island Strongly Connected Components") {
withSpark { sc =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
index 293c7f3ba4c21..c47552cf3a3bd 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.graphx.lib
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx._
import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut
-class TriangleCountSuite extends FunSuite with LocalSparkContext {
+class TriangleCountSuite extends SparkFunSuite with LocalSparkContext {
test("Count a single triangle") {
withSpark { sc =>
@@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext {
val triangles =
Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
- val revTriangles = triangles.map { case (a,b) => (b,a) }
+ val revTriangles = triangles.map { case (a, b) => (b, a) }
val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
val triangleCount = graph.triangleCount()
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
index f3b3738db0dad..186d0cc2a977b 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class BytecodeUtilsSuite extends FunSuite {
+class BytecodeUtilsSuite extends SparkFunSuite {
import BytecodeUtilsSuite.TestClass
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index 8d9c8ddccbb3c..32e0c841c6997 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.graphx.util
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.LocalSparkContext
-class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
+class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext {
test("GraphGenerators.generateRandomEdges") {
val src = 5
diff --git a/launcher/pom.xml b/launcher/pom.xml
index ebfa7685eaa18..48dd0d5f9106b 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,14 +22,14 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

   <groupId>org.apache.spark</groupId>
   <artifactId>spark-launcher_2.10</artifactId>
   <packaging>jar</packaging>
-  <name>Spark Launcher Project</name>
+  <name>Spark Project Launcher</name>
   <url>http://spark.apache.org/</url>
   <properties>
     <sbt.project.name>launcher</sbt.project.name>
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 33fd813f7a86c..33d65d13f0d25 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -296,6 +296,9 @@ Properties loadPropertiesFile() throws IOException {
try {
fd = new FileInputStream(propsFile);
props.load(new InputStreamReader(fd, "UTF-8"));
+      for (Map.Entry<Object, Object> e : props.entrySet()) {
+        e.setValue(e.getValue().toString().trim());
+      }
diff --git a/mllib/pom.xml b/mllib/pom.xml
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-streaming_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 9e16e60270141..e9a5d7c0e7988 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* Abstract class for estimators that fit models to data.
*/
-@AlphaComponent
+@DeveloperApi
abstract class Estimator[M <: Model[M]] extends PipelineStage {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 70e7495ac616c..186bf7ae7a2f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -17,16 +17,16 @@
package org.apache.spark.ml
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.ParamMap
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
*
* @tparam M model type
*/
-@AlphaComponent
+@DeveloperApi
abstract class Model[M <: Model[M]] extends Transformer {
/**
* The parent estimator that produced this model.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 43bee1b770e67..a9bd28df71ee1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -17,20 +17,23 @@
package org.apache.spark.ml
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
*/
-@AlphaComponent
+@DeveloperApi
abstract class PipelineStage extends Params with Logging {
/**
@@ -69,7 +72,7 @@ abstract class PipelineStage extends Params with Logging {
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
* of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the
* stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will
@@ -80,7 +83,7 @@ abstract class PipelineStage extends Params with Logging {
* transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
* an identity transformer.
*/
-@AlphaComponent
+@Experimental
class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
def this() = this(Identifiable.randomUID("pipeline"))
@@ -97,12 +100,9 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
/** @group getParam */
def getStages: Array[PipelineStage] = $(stages).clone()
- override def validateParams(paramMap: ParamMap): Unit = {
- val map = extractParamMap(paramMap)
- getStages.foreach {
- case pStage: Params => pStage.validateParams(map)
- case _ =>
- }
+ override def validateParams(): Unit = {
+ super.validateParams()
+ $(stages).foreach(_.validateParams())
}
/**
@@ -169,15 +169,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Represents a fitted pipeline.
*/
-@AlphaComponent
+@Experimental
class PipelineModel private[ml] (
override val uid: String,
val stages: Array[Transformer])
extends Model[PipelineModel] with Logging {
+ /** A Java/Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
+ this(uid, stages.asScala.toArray)
+ }
+
override def validateParams(): Unit = {
super.validateParams()
stages.foreach(_.validateParams())
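
The `Pipeline`/`PipelineModel` pair promoted to `Experimental` above is used roughly as in this text-classification sketch (`sqlContext` and the column names are assumptions of the example, not part of the patch):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Tiny illustrative training set.
val training = sqlContext.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
// Stages run in order: the two transformers feed the estimator, and the
// fitted result is a PipelineModel holding only transformers.
val model = pipeline.fit(training)
```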
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ec0f76aa668bd..e752b81a14282 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params
/**
* :: DeveloperApi ::
- *
* Abstraction for prediction problems (regression and classification).
*
* @tparam FeaturesType Type of features.
@@ -113,7 +112,6 @@ abstract class Predictor[
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
- @DeveloperApi
private[ml] def featuresDataType: DataType = new VectorUDT
override def transformSchema(schema: StructType): StructType = {
@@ -134,7 +132,6 @@ abstract class Predictor[
/**
* :: DeveloperApi ::
- *
* Abstraction for a model for prediction tasks (regression and classification).
*
* @tparam FeaturesType Type of features.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 38bb6a5a5391e..f07f733a5ddb5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.annotation.varargs
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.sql.DataFrame
@@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* Abstract class for transformers that transform one dataset into another.
*/
-@AlphaComponent
+@DeveloperApi
abstract class Transformer extends PipelineStage {
/**
@@ -73,10 +73,12 @@ abstract class Transformer extends PipelineStage {
}
/**
+ * :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
-private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
+@DeveloperApi
+abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
/** @group setParam */
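
Since `UnaryTransformer` is now public, it is worth recalling the one-input/one-output contract it and `Transformer` expose; a hedged sketch using an existing subclass (`df` is an assumed DataFrame with a `text` column):

```scala
import org.apache.spark.ml.feature.Tokenizer

// Illustrative only.
val tok = new Tokenizer().setInputCol("text").setOutputCol("words")
val withWords = tok.transform(df)
// The varargs transform() overload overrides params for a single call.
val withTokens = tok.transform(df, tok.outputCol -> "tokens")
```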
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index f5f37aa77929c..457c15830fd38 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
/**
+ * :: DeveloperApi ::
* Attributes that describe a vector ML column.
*
* @param name name of the attribute group (the ML column name)
@@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
* @param attrs optional array of attributes. Attribute will be copied with their corresponding
* indices in the array.
*/
+@DeveloperApi
class AttributeGroup private (
val name: String,
val numAttributes: Option[Int],
@@ -182,7 +185,11 @@ class AttributeGroup private (
}
}
-/** Factory methods to create attribute groups. */
+/**
+ * :: DeveloperApi ::
+ * Factory methods to create attribute groups.
+ */
+@DeveloperApi
object AttributeGroup {
import AttributeKeys._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
index a83febd7de2cc..5c7089b491677 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
@@ -17,12 +17,17 @@
package org.apache.spark.ml.attribute
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
* An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]],
* and [[AttributeType$#Binary]].
*/
+@DeveloperApi
sealed abstract class AttributeType(val name: String)
+@DeveloperApi
object AttributeType {
/** Numeric type. */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index e8f7f152784a1..ce43a450daad0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute
import scala.annotation.varargs
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField}
/**
+ * :: DeveloperApi ::
* Abstract class for ML attributes.
*/
+@DeveloperApi
sealed abstract class Attribute extends Serializable {
name.foreach { n =>
@@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory {
}
}
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
object Attribute extends AttributeFactory {
private[attribute] override def fromMetadata(metadata: Metadata): Attribute = {
@@ -163,6 +170,7 @@ object Attribute extends AttributeFactory {
/**
+ * :: DeveloperApi ::
* A numeric attribute with optional summary statistics.
* @param name optional name
* @param index optional index
@@ -171,6 +179,7 @@ object Attribute extends AttributeFactory {
* @param std optional standard deviation
* @param sparsity optional sparsity (ratio of zeros)
*/
+@DeveloperApi
class NumericAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
@@ -278,8 +287,10 @@ class NumericAttribute private[ml] (
}
/**
+ * :: DeveloperApi ::
* Factory methods for numeric attributes.
*/
+@DeveloperApi
object NumericAttribute extends AttributeFactory {
/** The default numeric attribute. */
@@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory {
}
/**
+ * :: DeveloperApi ::
* A nominal attribute.
* @param name optional name
* @param index optional index
@@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory {
* defined.
* @param values optional values. At most one of `numValues` and `values` can be defined.
*/
+@DeveloperApi
class NominalAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
@@ -430,7 +443,11 @@ class NominalAttribute private[ml] (
}
}
-/** Factory methods for nominal attributes. */
+/**
+ * :: DeveloperApi ::
+ * Factory methods for nominal attributes.
+ */
+@DeveloperApi
object NominalAttribute extends AttributeFactory {
/** The default nominal attribute. */
@@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory {
}
/**
+ * :: DeveloperApi ::
* A binary attribute.
* @param name optional name
* @param index optional index
* @param values optional values. If set, its size must be 2.
*/
+@DeveloperApi
class BinaryAttribute private[ml] (
override val name: Option[String] = None,
override val index: Option[Int] = None,
@@ -526,7 +545,11 @@ class BinaryAttribute private[ml] (
}
}
-/** Factory methods for binary attributes. */
+/**
+ * :: DeveloperApi ::
+ * Factory methods for binary attributes.
+ */
+@DeveloperApi
object BinaryAttribute extends AttributeFactory {
/** The default binary attribute. */
@@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory {
}
/**
+ * :: DeveloperApi ::
* An unresolved attribute.
*/
+@DeveloperApi
object UnresolvedAttribute extends Attribute {
override def attrType: AttributeType = AttributeType.Unresolved
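
The attribute hierarchy opened up as `DeveloperApi` above is mainly used to attach ML metadata to DataFrame columns; a small sketch (names are illustrative):

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute, NumericAttribute}

// Build a numeric and a nominal attribute with optional statistics/values.
val rating = NumericAttribute.defaultAttr.withName("rating").withMin(0.0).withMax(5.0)
val genre = NominalAttribute.defaultAttr.withName("genre").withValues("action", "drama")
// Attributes round-trip through StructField metadata.
val field = rating.toStructField()
val restored = Attribute.fromStructField(field)
```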
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7c961332bf5b6..8030e0728a56c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.classification
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,14 +31,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
* for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
-@AlphaComponent
+@Experimental
final class DecisionTreeClassifier(override val uid: String)
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
@@ -89,19 +88,19 @@ final class DecisionTreeClassifier(override val uid: String)
}
}
+@Experimental
object DecisionTreeClassifier {
/** Accessor for supported impurities: entropy, gini */
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
-@AlphaComponent
+@Experimental
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
override val rootNode: Node)
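
A hedged usage sketch of the now-`Experimental` `DecisionTreeClassifier` (the label column must carry class metadata, hence the `StringIndexer`; `data` is an assumed DataFrame with `label` and `features` columns):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.StringIndexer

// Index the raw labels so the classifier can read the number of classes.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setMaxDepth(5)
  .setImpurity("gini")
val model = new Pipeline().setStages(Array(labelIndexer, dt)).fit(data)
```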
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d504d84beb91e..62f4b51f770e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -20,11 +20,11 @@ package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -36,14 +36,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
*/
-@AlphaComponent
+@Experimental
final class GBTClassifier(override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTParams with TreeClassifierParams with Logging {
@@ -144,6 +143,7 @@ final class GBTClassifier(override val uid: String)
}
}
+@Experimental
object GBTClassifier {
// The losses below should be lowercase.
/** Accessor for supported loss settings: logistic */
@@ -151,8 +151,7 @@ object GBTClassifier {
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* model for classification.
* It supports binary labels, as well as both continuous and categorical features.
@@ -160,7 +159,7 @@ object GBTClassifier {
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
-@AlphaComponent
+@Experimental
final class GBTClassificationModel(
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
@@ -209,7 +208,7 @@ private[ml] object GBTClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8694c96e4c5b6..f136bcee9cf2b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
-import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
-import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import breeze.optimize.{CachedDiffFunction, DiffFunction}
+import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
@@ -35,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.{SparkException, Logging}
/**
* Params for logistic regression.
@@ -45,12 +44,11 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
with HasThreshold
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Logistic regression.
* Currently, this class only supports binary classification.
*/
-@AlphaComponent
+@Experimental
class LogisticRegression(override val uid: String)
extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
with LogisticRegressionParams with Logging {
@@ -76,7 +74,7 @@ class LogisticRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -92,7 +90,11 @@ class LogisticRegression(override val uid: String)
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
- /** @group setParam */
+ /**
+ * Whether to fit an intercept term.
+ * Default is true.
+ * @group setParam
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
@@ -221,11 +223,10 @@ class LogisticRegression(override val uid: String)
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Model produced by [[LogisticRegression]].
*/
-@AlphaComponent
+@Experimental
class LogisticRegressionModel private[ml] (
override val uid: String,
val weights: Vector,
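
The setters documented above combine into the usual elastic-net workflow; a sketch (assuming `training` is a DataFrame of binary labels and feature vectors):

```scala
import org.apache.spark.ml.classification.LogisticRegression

// Illustrative only.
val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.01)
  .setElasticNetParam(0.5) // 0 = L2, 1 = L1
  .setFitIntercept(true)
  .setTol(1e-6)
val lrModel = lr.fit(training)
println(s"Weights: ${lrModel.weights}, intercept: ${lrModel.intercept}")
```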
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1543f051ccd17..825f9ed1b54b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,7 +21,7 @@ import java.util.UUID
import scala.language.existentials
-import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.Param
@@ -37,11 +37,13 @@ import org.apache.spark.storage.StorageLevel
*/
private[ml] trait OneVsRestParams extends PredictorParams {
+ // scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
type M <: ClassificationModel[F, M]
type E <: Classifier[F, E, M]
}
+ // scalastyle:on structural.type
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -54,8 +56,7 @@ private[ml] trait OneVsRestParams extends PredictorParams {
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Model produced by [[OneVsRest]].
* This stores the models resulting from training k binary classifiers: one for each class.
* Each example is scored against all k models, and the model with the highest score
@@ -67,11 +68,11 @@ private[ml] trait OneVsRestParams extends PredictorParams {
* The i-th model is produced by testing the i-th class (taking label 1) vs the rest
* (taking label 0).
*/
-@AlphaComponent
+@Experimental
final class OneVsRestModel private[ml] (
override val uid: String,
labelMetadata: Metadata,
- val models: Array[_ <: ClassificationModel[_,_]])
+ val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams {
override def transformSchema(schema: StructType): StructType = {
@@ -105,17 +106,17 @@ final class OneVsRestModel private[ml] (
// add temporary column to store intermediate scores and update
val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
- val update: (Map[Int, Double], Vector) => Map[Int, Double] =
+ val update: (Map[Int, Double], Vector) => Map[Int, Double] =
(predictions: Map[Int, Double], prediction: Vector) => {
predictions + ((index, prediction(1)))
}
val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
- val transformedDataset = model.transform(df).select(columns:_*)
+ val transformedDataset = model.transform(df).select(columns : _*)
val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
val newColumns = origCols ++ List(col(tmpColName))
// switch out the intermediate column with the accumulator column
- updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
+ updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName)
}
if (handlePersistence) {
@@ -130,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
@@ -191,7 +193,7 @@ final class OneVsRest(override val uid: String)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
- }.toArray[ClassificationModel[_,_]]
+ }.toArray[ClassificationModel[_, _]]
if (handlePersistence) {
multiclassLabeled.unpersist()
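
A minimal sketch of the `OneVsRest` reduction described above (assuming `data` is a DataFrame with a numeric multiclass `label` column and a `features` column):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// Illustrative only: train one binary LogisticRegression per class.
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
val ovrModel = ovr.fit(data)
val predictions = ovrModel.transform(data)
```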
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index a1de7919859eb..852a67e066322 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -19,10 +19,10 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -33,14 +33,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for
* classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
*/
-@AlphaComponent
+@Experimental
final class RandomForestClassifier(override val uid: String)
extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
@@ -100,6 +99,7 @@ final class RandomForestClassifier(override val uid: String)
}
}
+@Experimental
object RandomForestClassifier {
/** Accessor for supported impurity settings: entropy, gini */
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -110,15 +110,14 @@ object RandomForestClassifier {
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
*/
-@AlphaComponent
+@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel])
@@ -171,7 +170,7 @@ private[ml] object RandomForestClassificationModel {
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index ddbdd00ceb159..f695ddaeefc72 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.evaluation
-import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
@@ -28,11 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Evaluator for binary classification, which expects two input columns: score and label.
*/
-@AlphaComponent
+@Experimental
class BinaryClassificationEvaluator(override val uid: String)
extends Evaluator with HasRawPredictionCol with HasLabelCol {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index cabd1c97c085c..61e937e693699 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -17,15 +17,15 @@
package org.apache.spark.ml.evaluation
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* Abstract class for evaluators that compute metrics from predictions.
*/
-@AlphaComponent
+@DeveloperApi
abstract class Evaluator extends Params {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 80458928c5439..abb1b35bedea5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.evaluation
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.{Param, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
@@ -26,19 +26,18 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Evaluator for regression, which expects two input columns: prediction and label.
*/
-@AlphaComponent
+@Experimental
final class RegressionEvaluator(override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol {
def this() = this(Identifiable.randomUID("regEval"))
/**
- * param for metric name in evaluation
- * @group param supports mse, rmse, r2, mae as valid metric names.
+ * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ * @group param
*/
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
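
The updated param doc lists the supported metric names; a quick sketch of how one is selected (invented predictions, spark-shell assumed):

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val predictions = sqlContext.createDataFrame(Seq(
      (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)
    )).toDF("prediction", "label")

    // "rmse" is the default; "mse", "r2" and "mae" can be chosen the same way.
    val evaluator = new RegressionEvaluator().setMetricName("rmse")
    println(s"rmse = ${evaluator.evaluate(predictions)}")
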
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 62f4a6343423e..b06122d733853 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
@@ -28,10 +28,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Binarize a column of continuous features given a threshold.
*/
-@AlphaComponent
+@Experimental
final class Binarizer(override val uid: String)
extends Transformer with HasInputCol with HasOutputCol {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index ac8dfb5632a7b..a3d1f6f65ccaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import java.{util => ju}
import org.apache.spark.SparkException
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
@@ -31,10 +31,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
*/
-@AlphaComponent
+@Experimental
final class Bucketizer(override val uid: String)
extends Model[Bucketizer] with HasInputCol with HasOutputCol {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 8b32eee0e490a..1e758cb775de7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
@@ -26,12 +26,12 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
* provided "weight" vector. In other words, it scales each column of the dataset by a scalar
* multiplier.
*/
-@AlphaComponent
+@Experimental
class ElementwiseProduct(override val uid: String)
extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
@@ -41,7 +41,7 @@ class ElementwiseProduct(override val uid: String)
* the vector to multiply with input vectors
* @group param
*/
- val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product")
+ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
/** @group setParam */
def setScalingVec(value: Vector): this.type = set(scalingVec, value)
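
For reference, a sketch of the renamed scalingVec param in use (toy vectors; a spark-shell with sqlContext is assumed):

    import org.apache.spark.ml.feature.ElementwiseProduct
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      ("a", Vectors.dense(1.0, 2.0, 3.0)),
      ("b", Vectors.dense(4.0, 5.0, 6.0))
    )).toDF("id", "vector")

    // Each input vector is multiplied element-wise by the scaling vector (0, 1, 2).
    val transformer = new ElementwiseProduct()
      .setScalingVec(Vectors.dense(0.0, 1.0, 2.0))
      .setInputCol("vector")
      .setOutputCol("transformedVector")
    transformer.transform(df).show()
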
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 8942d45219177..f936aef80f8af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -17,22 +17,22 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.{udf, col}
+import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Maps a sequence of terms to their term frequencies using the hashing trick.
*/
-@AlphaComponent
+@Experimental
class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("hashingTF"))
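
A brief sketch of the hashing-trick transformer on tokenized text (hypothetical sentences; spark-shell assumed):

    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

    val sentences = sqlContext.createDataFrame(Seq(
      (0, "spark ml uses the hashing trick"),
      (1, "hashing maps terms to fixed indices")
    )).toDF("id", "sentence")

    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    // numFeatures bounds the dimensionality of the resulting term-frequency vectors.
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(1000)
    hashingTF.transform(tokenizer.transform(sentences)).select("words", "features").show()
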
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 788c392050c2d..376b84530cd57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -58,10 +58,10 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
-@AlphaComponent
+@Experimental
final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
def this() = this(Identifiable.randomUID("idf"))
@@ -85,10 +85,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Model fitted by [[IDF]].
*/
-@AlphaComponent
+@Experimental
class IDFModel private[ml] (
override val uid: String,
idfModel: feature.IDFModel)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 3f689d1585cd6..8282e5ffa17f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
@@ -26,10 +26,10 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Normalize a vector to have unit norm using the given p-norm.
*/
-@AlphaComponent
+@Experimental
class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
def this() = this(Identifiable.randomUID("normalizer"))
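
As a rough illustration of the p-norm parameter (sample vectors invented; spark-shell session assumed):

    import org.apache.spark.ml.feature.Normalizer
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (0, Vectors.dense(1.0, 0.5, -1.0)),
      (1, Vectors.dense(2.0, 1.0, 1.0))
    )).toDF("id", "features")

    // p = 1.0 rescales each row to unit L^1 norm; the default is the L^2 norm.
    val normalizer = new Normalizer().setInputCol("features").setOutputCol("normFeatures").setP(1.0)
    normalizer.transform(df).show()
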
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 1fb9b9ae75091..8f34878c8d329 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,93 +17,152 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkException
-import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
-import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{DoubleType, StructType}
/**
- * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
- * at most a single one-value. By default, the binary vector has an element for each category, so
- * with 5 categories, an input value of 2.0 would map to an output vector of
- * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
- * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
- * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
- * linearly dependent because they sum up to one.
+ * :: Experimental ::
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]])
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * @see [[StringIndexer]] for converting categorical values into category indices
*/
-@AlphaComponent
-class OneHotEncoder(override val uid: String)
- extends UnaryTransformer[Double, Vector, OneHotEncoder] with HasInputCol with HasOutputCol {
+@Experimental
+class OneHotEncoder(override val uid: String) extends Transformer
+ with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("oneHot"))
/**
- * Whether to include a component in the encoded vectors for the first category, defaults to true.
+ * Whether to drop the last category in the encoded vector (default: true)
* @group param
*/
- final val includeFirst: BooleanParam =
- new BooleanParam(this, "includeFirst", "include first category")
- setDefault(includeFirst -> true)
-
- private var categories: Array[String] = _
+ final val dropLast: BooleanParam =
+ new BooleanParam(this, "dropLast", "whether to drop the last category")
+ setDefault(dropLast -> true)
/** @group setParam */
- def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+ def setDropLast(value: Boolean): this.type = set(dropLast, value)
/** @group setParam */
- override def setInputCol(value: String): this.type = set(inputCol, value)
+ def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
- override def setOutputCol(value: String): this.type = set(outputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
- val inputFields = schema.fields
+ val is = "_is_"
+ val inputColName = $(inputCol)
val outputColName = $(outputCol)
- require(inputFields.forall(_.name != $(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
- categories = inputColAttr match {
+ SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+ val inputFields = schema.fields
+ require(!inputFields.exists(_.name == outputColName),
+ s"Output column $outputColName already exists.")
+
+ val inputAttr = Attribute.fromStructField(schema(inputColName))
+ val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
- nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
- case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+ if (nominal.values.isDefined) {
+ nominal.values.map(_.map(v => inputColName + is + v))
+ } else if (nominal.numValues.isDefined) {
+ nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+ } else {
+ None
+ }
+ case binary: BinaryAttribute =>
+ if (binary.values.isDefined) {
+ binary.values.map(_.map(v => inputColName + is + v))
+ } else {
+ Some(Array.tabulate(2)(i => inputColName + is + i))
+ }
+ case _: NumericAttribute =>
+ throw new RuntimeException(
+ s"The input column $inputColName cannot be numeric.")
case _ =>
- throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
+ None // optimistic about unknown attributes
}
- val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
- val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
- val outputFields = inputFields :+ attr.toStructField()
+ val filteredOutputAttrNames = outputAttrNames.map { names =>
+ if ($(dropLast)) {
+ require(names.length > 1,
+ s"The input column $inputColName should have at least two distinct values.")
+ names.dropRight(1)
+ } else {
+ names
+ }
+ }
+
+ val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
+ val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
+ BinaryAttribute.defaultAttr.withName(name)
+ }
+ new AttributeGroup($(outputCol), attrs)
+ } else {
+ new AttributeGroup($(outputCol))
+ }
+
+ val outputFields = inputFields :+ outputAttrGroup.toStructField()
StructType(outputFields)
}
- protected override def createTransformFunc(): (Double) => Vector = {
- val first = $(includeFirst)
- val vecLen = if (first) categories.length else categories.length - 1
+ override def transform(dataset: DataFrame): DataFrame = {
+ // schema transformation
+ val is = "_is_"
+ val inputColName = $(inputCol)
+ val outputColName = $(outputCol)
+ val shouldDropLast = $(dropLast)
+ var outputAttrGroup = AttributeGroup.fromStructField(
+ transformSchema(dataset.schema)(outputColName))
+ if (outputAttrGroup.size < 0) {
+ // If the number of attributes is unknown, we check the values from the input column.
+ val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0))
+ .aggregate(0.0)(
+ (m, x) => {
+ assert(x >=0.0 && x == x.toInt,
+ s"Values from column $inputColName must be indices, but got $x.")
+ math.max(m, x)
+ },
+ (m0, m1) => {
+ math.max(m0, m1)
+ }
+ ).toInt + 1
+ val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+ val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
+ val outputAttrs: Array[Attribute] =
+ filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
+ outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
+ }
+ val metadata = outputAttrGroup.toMetadata()
+
+ // data transformation
+ val size = outputAttrGroup.size
val oneValue = Array(1.0)
val emptyValues = Array[Double]()
val emptyIndices = Array[Int]()
- label: Double => {
- val values = if (first || label != 0.0) oneValue else emptyValues
- val indices = if (first) {
- Array(label.toInt)
- } else if (label != 0.0) {
- Array(label.toInt - 1)
+ val encode = udf { label: Double =>
+ if (label < size) {
+ Vectors.sparse(size, Array(label.toInt), oneValue)
} else {
- emptyIndices
+ Vectors.sparse(size, emptyIndices, emptyValues)
}
- Vectors.sparse(vecLen, indices, values)
}
- }
- /**
- * Returns the data type of the output column.
- */
- protected def outputDataType: DataType = new VectorUDT
+ dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
+ }
}
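
To see the dropLast behaviour described in the new scaladoc, a short sketch (categories and column names invented; spark-shell assumed). With three distinct categories and the default dropLast = true, the output vectors have two slots and the last category encodes as all zeros:

    import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

    val df = sqlContext.createDataFrame(Seq(
      (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
    )).toDF("id", "category")

    // StringIndexer supplies the category indices (and their metadata) that the encoder consumes.
    val indexed = new StringIndexer()
      .setInputCol("category").setOutputCol("categoryIndex")
      .fit(df).transform(df)

    val encoded = new OneHotEncoder()
      .setInputCol("categoryIndex").setOutputCol("categoryVec")
      .transform(indexed)
    encoded.select("category", "categoryVec").show()
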
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index 8ddf9d6a1e138..442e95820217a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
@@ -27,14 +27,14 @@ import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
 * Perform feature expansion in a polynomial space. As the Wikipedia article on
 * [[http://en.wikipedia.org/wiki/Polynomial_expansion polynomial expansion]] puts it, "In mathematics, an
* expansion of a product of sums expresses it as a sum of products by using the fact that
* multiplication distributes over addition". Take a 2-variable feature vector as an example:
* `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`.
*/
-@AlphaComponent
+@Experimental
class PolynomialExpansion(override val uid: String)
extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
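
A compact sketch of the degree-2 expansion mentioned in the scaladoc (input vectors invented; spark-shell assumed):

    import org.apache.spark.ml.feature.PolynomialExpansion
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      Tuple1(Vectors.dense(2.0, 1.0)),
      Tuple1(Vectors.dense(3.0, -1.0))
    )).toDF("features")

    // With degree 2, (x, y) expands to (x, x*x, y, x*y, y*y).
    val polyExpansion = new PolynomialExpansion()
      .setInputCol("features").setOutputCol("polyFeatures").setDegree(2)
    polyExpansion.transform(df).show()
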
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 5ccda15d872ed..b0fd06d84fdb3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -35,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/**
* Centers the data with mean before scaling.
- * It will build a dense output, so this does not work on sparse input
+ * It will build a dense output, so this does not work on sparse input
* and will raise an exception.
* Default: false
* @group param
*/
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
-
+
/**
* Scales the data to unit standard deviation.
* Default: true
@@ -51,11 +51,11 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Standardizes features by removing the mean and scaling to unit variance using column summary
* statistics on the samples in the training set.
*/
-@AlphaComponent
+@Experimental
class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
with StandardScalerParams {
@@ -68,13 +68,13 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
-
+
/** @group setParam */
def setWithMean(value: Boolean): this.type = set(withMean, value)
-
+
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
-
+
override def fit(dataset: DataFrame): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
@@ -95,10 +95,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Model fitted by [[StandardScaler]].
*/
-@AlphaComponent
+@Experimental
class StandardScalerModel private[ml] (
override val uid: String,
scaler: feature.StandardScalerModel)
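
A sketch of the withMean / withStd params discussed above (toy features; spark-shell assumed). Keeping withMean = false avoids densifying sparse input, as the param doc warns:

    import org.apache.spark.ml.feature.StandardScaler
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (0, Vectors.dense(1.0, 10.0)),
      (1, Vectors.dense(2.0, 20.0)),
      (2, Vectors.dense(3.0, 30.0))
    )).toDF("id", "features")

    val scaler = new StandardScaler()
      .setInputCol("features").setOutputCol("scaledFeatures")
      .setWithMean(false).setWithStd(true)

    val scalerModel = scaler.fit(df)   // computes column summary statistics on the training set
    scalerModel.transform(df).show()
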
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 3f79b67309f07..f4e250757560a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
@@ -52,13 +52,13 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A label indexer that maps a string column of labels to an ML column of label indices.
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
*/
-@AlphaComponent
+@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
with StringIndexerBase {
@@ -86,10 +86,13 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Model fitted by [[StringIndexer]].
+ * NOTE: During transformation, if the input column does not exist,
+ * [[StringIndexerModel.transform]] returns the input dataset unmodified.
+ * This is a temporary fix for the case when target labels do not exist during prediction.
*/
-@AlphaComponent
+@Experimental
class StringIndexerModel private[ml] (
override val uid: String,
labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = {
+ if (!dataset.schema.fieldNames.contains($(inputCol))) {
+ logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
+ "Skip StringIndexerModel.")
+ return dataset
+ }
+
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ if (schema.fieldNames.contains($(inputCol))) {
+ validateAndTransformSchema(schema)
+ } else {
+ // If the input column does not exist during transformation, we skip StringIndexerModel.
+ schema
+ }
}
}
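
A minimal sketch of the frequency-based indexing described above (invented categories; spark-shell assumed). Because "a" is the most frequent value it receives index 0.0, and per the new NOTE the fitted model passes a dataset through unchanged when the input column is missing:

    import org.apache.spark.ml.feature.StringIndexer

    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    ).toDF("id", "category")

    val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
    val model = indexer.fit(df)
    model.transform(df).show()

    // A dataset without the "category" column is returned as-is instead of failing.
    model.transform(df.select("id")).show()
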
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 31f3a1aa4c76b..21c15b6c33f6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -17,19 +17,19 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*
* @see [[RegexTokenizer]]
*/
-@AlphaComponent
+@Experimental
class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
def this() = this(Identifiable.randomUID("tok"))
@@ -46,13 +46,13 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A regex based tokenizer that extracts tokens either by using the provided regex pattern to split
* the text (default) or repeatedly matching the regex (if `gaps` is true).
* Optional parameters also allow filtering tokens using a minimal length.
* It returns an array of strings that can be empty.
*/
-@AlphaComponent
+@Experimental
class RegexTokenizer(override val uid: String)
extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 181b62f46fce8..229ee27ec5942 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
@@ -30,14 +31,14 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A feature transformer that merges multiple columns into a vector column.
*/
-@AlphaComponent
+@Experimental
class VectorAssembler(override val uid: String)
extends Transformer with HasInputCols with HasOutputCol {
- def this() = this(Identifiable.randomUID("va"))
+ def this() = this(Identifiable.randomUID("vecAssembler"))
/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
@@ -46,19 +47,59 @@ class VectorAssembler(override val uid: String)
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = {
+ // Schema transformation.
+ val schema = dataset.schema
+ lazy val first = dataset.first()
+ val attrs = $(inputCols).flatMap { c =>
+ val field = schema(c)
+ val index = schema.fieldIndex(c)
+ field.dataType match {
+ case DoubleType =>
+ val attr = Attribute.fromStructField(field)
+ // If the input column doesn't have ML attribute, assume numeric.
+ if (attr == UnresolvedAttribute) {
+ Some(NumericAttribute.defaultAttr.withName(c))
+ } else {
+ Some(attr.withName(c))
+ }
+ case _: NumericType | BooleanType =>
+ // If the input column type is a compatible scalar type, assume numeric.
+ Some(NumericAttribute.defaultAttr.withName(c))
+ case _: VectorUDT =>
+ val group = AttributeGroup.fromStructField(field)
+ if (group.attributes.isDefined) {
+ // If attributes are defined, copy them with updated names.
+ group.attributes.get.map { attr =>
+ if (attr.name.isDefined) {
+ // TODO: Define a rigorous naming scheme.
+ attr.withName(c + "_" + attr.name.get)
+ } else {
+ attr
+ }
+ }
+ } else {
+ // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
+ // from metadata, check the first row.
+ val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
+ Array.fill(numAttrs)(NumericAttribute.defaultAttr)
+ }
+ }
+ }
+ val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
+
+ // Data transformation.
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
}
- val schema = dataset.schema
- val inputColNames = $(inputCols)
- val args = inputColNames.map { c =>
+ val args = $(inputCols).map { c =>
schema(c).dataType match {
case DoubleType => dataset(c)
case _: VectorUDT => dataset(c)
case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
}
}
- dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol)))
+
+ dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata))
}
override def transformSchema(schema: StructType): StructType = {
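
To illustrate the assembler, and the ML attribute metadata it now attaches to the output column, a small sketch with invented columns, assuming a spark-shell session:

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      (0, 18.0, 1.0, Vectors.dense(0.0, 10.0, 0.5))
    )).toDF("id", "hour", "clicked", "userFeatures")

    // Numeric, boolean and vector input columns are concatenated into one feature vector.
    val assembler = new VectorAssembler()
      .setInputCols(Array("hour", "clicked", "userFeatures"))
      .setOutputCol("features")
    assembler.transform(df).select("features").show()
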
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index e238fb310ed37..1d0f23b4fb3db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,7 +22,7 @@ import java.util.{Map => JMap}
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
@@ -56,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Class for indexing categorical feature columns in a dataset of [[Vector]].
*
* This has 2 usage modes:
@@ -91,7 +90,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
* - Add warning if a categorical feature has only 1 category.
* - Add option for allowing unknown categories.
*/
-@AlphaComponent
+@Experimental
class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
with VectorIndexerParams {
@@ -230,8 +229,7 @@ private object VectorIndexer {
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Transform categorical features to use 0-based indices instead of their original values.
* - Categorical features are mapped to indices.
* - Continuous features (columns) are left unchanged.
@@ -246,7 +244,7 @@ private object VectorIndexer {
* Values are maps from original features values to 0-based category indices.
* If a feature is not in this map, it is treated as continuous.
*/
-@AlphaComponent
+@Experimental
class VectorIndexerModel private[ml] (
override val uid: String,
val numFeatures: Int,
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index ed032669229ce..36f19509f0cfb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -82,11 +82,11 @@ private[feature] trait Word2VecBase extends Params
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further
* natural language processing or machine learning process.
*/
-@AlphaComponent
+@Experimental
final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase {
def this() = this(Identifiable.randomUID("w2v"))
@@ -135,10 +135,10 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Model fitted by [[Word2Vec]].
*/
-@AlphaComponent
+@Experimental
class Word2VecModel private[ml] (
override val uid: String,
wordVectors: feature.Word2VecModel)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
index 00d9c802e930d..87f4223964ada 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java
+++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
@@ -16,10 +16,10 @@
*/
/**
- * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly
* assemble and configure practical machine learning pipelines.
*/
-@AlphaComponent
+@Experimental
package org.apache.spark.ml;
-import org.apache.spark.annotation.AlphaComponent;
+import org.apache.spark.annotation.Experimental;
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index ac75e9de1a8f2..c589d06d9f7e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark
/**
- * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly
* assemble and configure practical machine learning pipelines.
*
* @groupname param Parameters
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 12fc5b561f76e..ba94d6a3a80a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -24,11 +24,11 @@ import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.util.Identifiable
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
 * A param with self-contained documentation and an optional default value. Primitive-typed params
 * should use the specialized versions, which are more friendly to Java users.
*
@@ -39,7 +39,7 @@ import org.apache.spark.ml.util.Identifiable
* See [[ParamValidators]] for factory methods for common validation functions.
* @tparam T param value type
*/
-@AlphaComponent
+@DeveloperApi
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
extends Serializable {
@@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
}
- /**
- * Creates a param pair with the given value (for Java).
- */
+ /** Creates a param pair with the given value (for Java). */
def w(value: T): ParamPair[T] = this -> value
- /**
- * Creates a param pair with the given value (for Scala).
- */
+ /** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)
override final def toString: String = s"${parent}__$name"
@@ -174,7 +170,11 @@ object ParamValidators {
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
-/** Specialized version of [[Param[Double]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Double]]] for Java.
+ */
+@DeveloperApi
class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean)
extends Param[Double](parent, name, doc, isValid) {
@@ -186,10 +186,15 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value)
}
-/** Specialized version of [[Param[Int]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Int]]] for Java.
+ */
+@DeveloperApi
class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
extends Param[Int](parent, name, doc, isValid) {
@@ -201,10 +206,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
}
-/** Specialized version of [[Param[Float]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Float]]] for Java.
+ */
+@DeveloperApi
class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean)
extends Param[Float](parent, name, doc, isValid) {
@@ -216,10 +226,15 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)
}
-/** Specialized version of [[Param[Long]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Long]]] for Java.
+ */
+@DeveloperApi
class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean)
extends Param[Long](parent, name, doc, isValid) {
@@ -231,47 +246,60 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)
}
-/** Specialized version of [[Param[Boolean]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Boolean]]] for Java.
+ */
+@DeveloperApi
class BooleanParam(parent: String, name: String, doc: String) // No need for isValid
extends Param[Boolean](parent, name, doc) {
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
+ /** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
-/** Specialized version of [[Param[Array[String]]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Array[String]]]] for Java.
+ */
+@DeveloperApi
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
- override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
-/** Specialized version of [[Param[Array[Double]]]] for Java. */
+/**
+ * :: DeveloperApi ::
+ * Specialized version of [[Param[Array[Double]]]] for Java.
+ */
+@DeveloperApi
class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
extends Param[Array[Double]](parent, name, doc, isValid) {
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
- override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)
-
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
- def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
+ def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
+ w(value.asScala.map(_.asInstanceOf[Double]).toArray)
}
/**
+ * :: Experimental ::
 * A param and its value.
*/
+@Experimental
case class ParamPair[T](param: Param[T], value: T) {
// This is *the* place Param.validate is called. Whenever a parameter is specified, we should
// always construct a ParamPair so that validate is called.
@@ -279,11 +307,11 @@ case class ParamPair[T](param: Param[T], value: T) {
}
/**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
* Trait for components that take parameters. This also provides an internal param map to store
* parameter values attached to the instance.
*/
-@AlphaComponent
+@DeveloperApi
trait Params extends Identifiable with Serializable {
/**
@@ -303,19 +331,6 @@ trait Params extends Identifiable with Serializable {
.map(m => m.invoke(this).asInstanceOf[Param[_]])
}
- /**
- * Validates parameter values stored internally plus the input parameter map.
- * Raises an exception if any parameter is invalid.
- *
- * This only needs to check for interactions between parameters.
- * Parameter value checks which do not depend on other parameters are handled by
- * [[Param.validate()]]. This method does not handle input/output column parameters;
- * those are checked during schema validation.
- */
- def validateParams(paramMap: ParamMap): Unit = {
- copy(paramMap).validateParams()
- }
-
/**
* Validates parameter values stored internally.
* Raise an exception if any parameter value is invalid.
@@ -541,10 +556,10 @@ trait Params extends Identifiable with Serializable {
abstract class JavaParams extends Params
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* A param to value map.
*/
-@AlphaComponent
+@Experimental
final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
extends Serializable {
@@ -665,6 +680,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
def size: Int = map.size
}
+@Experimental
object ParamMap {
/**
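
Since ParamPair and ParamMap are user-facing here, a tiny sketch of how they are typically built on the Scala side (LogisticRegression is used purely as an example estimator):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression()
    // `->` builds a ParamPair (Scala); Java callers use the `w` methods instead.
    val paramMap = ParamMap(lr.maxIter -> 30, lr.regParam -> 0.01)
    // Values supplied this way override the instance's own settings, e.g. lr.fit(training, paramMap).
    println(paramMap)
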
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1ffb5eddc36bd..8ffbcf0d8bc71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
val params = Seq(
ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
- ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+ ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
isValid = "ParamValidators.gtEq(0)"),
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index ed08417bd4df8..a0c8ccdac9ad9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params {
private[ml] trait HasMaxIter extends Params {
/**
- * Param for max number of iterations (>= 0).
+ * Param for maximum number of iterations (>= 0).
* @group param
*/
- final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
+ final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
/** @group getParam */
final def getMaxIter: Int = $(maxIter)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2a5ddbfae5cdf..df009d855ecbb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -31,25 +31,50 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.netlib.util.intW
import org.apache.spark.{Logging, Partitioner}
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom
+/**
+ * Common params for ALS and ALSModel.
+ */
+private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
+ /**
+ * Param for the column name for user ids.
+ * Default: "user"
+ * @group param
+ */
+ val userCol = new Param[String](this, "userCol", "column name for user ids")
+
+ /** @group getParam */
+ def getUserCol: String = $(userCol)
+
+ /**
+ * Param for the column name for item ids.
+ * Default: "item"
+ * @group param
+ */
+ val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+
+ /** @group getParam */
+ def getItemCol: String = $(itemCol)
+}
+
/**
* Common params for ALS.
*/
-private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
with HasPredictionCol with HasCheckpointInterval with HasSeed {
/**
@@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
/** @group getParam */
def getAlpha: Double = $(alpha)
- /**
- * Param for the column name for user ids.
- * Default: "user"
- * @group param
- */
- val userCol = new Param[String](this, "userCol", "column name for user ids")
-
- /** @group getParam */
- def getUserCol: String = $(userCol)
-
- /**
- * Param for the column name for item ids.
- * Default: "item"
- * @group param
- */
- val itemCol = new Param[String](this, "itemCol", "column name for item ids")
-
- /** @group getParam */
- def getItemCol: String = $(itemCol)
-
/**
* Param for the column name for ratings.
* Default: "rating"
@@ -156,58 +161,66 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- require(schema($(userCol)).dataType == IntegerType)
- require(schema($(itemCol)).dataType== IntegerType)
+ SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+ SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
require(ratingType == FloatType || ratingType == DoubleType)
- val predictionColName = $(predictionCol)
- require(!schema.fieldNames.contains(predictionColName),
- s"Prediction column $predictionColName already exists.")
- val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
- StructType(newFields)
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
/**
+ * :: Experimental ::
* Model fitted by ALS.
+ *
+ * @param rank rank of the matrix factorization model
+ * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
+ * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
*/
+@Experimental
class ALSModel private[ml] (
override val uid: String,
- k: Int,
- userFactors: RDD[(Int, Array[Float])],
- itemFactors: RDD[(Int, Array[Float])])
- extends Model[ALSModel] with ALSParams {
+ val rank: Int,
+ @transient val userFactors: DataFrame,
+ @transient val itemFactors: DataFrame)
+ extends Model[ALSModel] with ALSModelParams {
+
+ /** @group setParam */
+ def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
+ def setItemCol(value: String): this.type = set(itemCol, value)
/** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame): DataFrame = {
- import dataset.sqlContext.implicits._
- val users = userFactors.toDF("id", "features")
- val items = itemFactors.toDF("id", "features")
-
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
if (userFeatures != null && itemFeatures != null) {
- blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+ blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
} else {
Float.NaN
}
}
dataset
- .join(users, dataset($(userCol)) === users("id"), "left")
- .join(items, dataset($(itemCol)) === items("id"), "left")
- .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
+ .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
+ .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+ .select(dataset("*"),
+ predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+ SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+ SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
}
}
/**
+ * :: Experimental ::
* Alternating Least Squares (ALS) matrix factorization.
*
* ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
@@ -236,6 +249,7 @@ class ALSModel private[ml] (
* indicated user
* preferences rather than explicit ratings given to items.
*/
+@Experimental
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
import org.apache.spark.ml.recommendation.ALS.Rating
@@ -295,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
}
override def fit(dataset: DataFrame): ALSModel = {
+ import dataset.sqlContext.implicits._
val ratings = dataset
.select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
col($(ratingCol)).cast(FloatType))
@@ -306,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
checkpointInterval = $(checkpointInterval), seed = $(seed))
- val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this)
+ val userDF = userFactors.toDF("id", "features")
+ val itemDF = itemFactors.toDF("id", "features")
+ val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
copyValues(model)
}
@@ -326,7 +343,11 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
@DeveloperApi
object ALS extends Logging {
- /** Rating class for better code readability. */
+ /**
+ * :: DeveloperApi ::
+ * Rating class for better code readability.
+ */
+ @DeveloperApi
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
/** Trait for least squares solvers applied to the normal equation. */
@@ -487,8 +508,10 @@ object ALS extends Logging {
}
/**
+ * :: DeveloperApi ::
* Implementation of the ALS algorithm.
*/
+ @DeveloperApi
def train[ID: ClassTag]( // scalastyle:ignore
ratings: RDD[Rating[ID]],
rank: Int = 10,
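
A short sketch of the reworked ALSModel, whose factors are now exposed as DataFrames (the ratings below are invented; a spark-shell with sqlContext is assumed):

    import org.apache.spark.ml.recommendation.ALS

    val ratings = sqlContext.createDataFrame(Seq(
      (0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f),
      (1, 2, 4.0f), (2, 0, 1.0f), (2, 2, 5.0f)
    )).toDF("user", "item", "rating")

    val als = new ALS().setRank(10).setMaxIter(5).setRegParam(0.01)
    val model = als.fit(ratings)

    model.userFactors.show()          // two columns: id, features
    model.transform(ratings).show()   // appends a "prediction" column
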
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index e67df21b2e4ae..43b68e7bb20fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.regression
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,13 +31,12 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
* for regression.
* It supports both continuous and categorical features.
*/
-@AlphaComponent
+@Experimental
final class DecisionTreeRegressor(override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeParams with TreeRegressorParams {
@@ -79,19 +78,19 @@ final class DecisionTreeRegressor(override val uid: String)
}
}
+@Experimental
object DecisionTreeRegressor {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
* It supports both continuous and categorical features.
* @param rootNode Root of the decision tree
*/
-@AlphaComponent
+@Experimental
final class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: Node)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 4249ff5c1ebc7..b7e374bb6cb49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for regression.
* It supports both continuous and categorical features.
*/
-@AlphaComponent
+@Experimental
final class GBTRegressor(override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
with GBTParams with TreeRegressorParams with Logging {
@@ -134,6 +133,7 @@ final class GBTRegressor(override val uid: String)
}
}
+@Experimental
object GBTRegressor {
// The losses below should be lowercase.
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@@ -141,7 +141,7 @@ object GBTRegressor {
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
*
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* model for regression.
@@ -149,7 +149,7 @@ object GBTRegressor {
* @param _trees Decision trees in the ensemble.
* @param _treeWeights Weights for the decision trees in the ensemble.
*/
-@AlphaComponent
+@Experimental
final class GBTRegressionModel(
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
@@ -198,7 +198,7 @@ private[ml] object GBTRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
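
Editor's note: for contrast with the single tree, a hedged sketch of configuring the `GBTRegressor` marked `@Experimental` above. `training` stands for any DataFrame with "label" and "features" columns, as in the previous sketch.

```scala
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
import org.apache.spark.sql.DataFrame

// `training`: any DataFrame with "label" and "features" columns (see the previous sketch).
def fitGbt(training: DataFrame): GBTRegressionModel = {
  new GBTRegressor()
    .setMaxIter(20)          // number of boosting iterations, i.e. trees in the ensemble
    .setMaxDepth(4)
    .setStepSize(0.1)        // shrinkage applied to each tree's contribution
    .setLossType("squared")  // "squared" (L2) or "absolute" (L1)
    .fit(training)
}
```
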
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 3ebb78f79201a..70cd8e9e87fae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -23,7 +23,7 @@ import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
@@ -44,8 +44,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Linear regression.
*
* The learning objective is to minimize the squared error, with regularization.
@@ -58,7 +57,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* - L1 (Lasso)
* - L2 + L1 (elastic net)
*/
-@AlphaComponent
+@Experimental
class LinearRegression(override val uid: String)
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams with Logging {
@@ -84,7 +83,7 @@ class LinearRegression(override val uid: String)
setDefault(elasticNetParam -> 0.0)
/**
- * Set the maximal number of iterations.
+ * Set the maximum number of iterations.
* Default is 100.
* @group setParam
*/
@@ -190,11 +189,10 @@ class LinearRegression(override val uid: String)
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* Model produced by [[LinearRegression]].
*/
-@AlphaComponent
+@Experimental
class LinearRegressionModel private[ml] (
override val uid: String,
val weights: Vector,
@@ -323,7 +321,7 @@ private class LeastSquaresAggregator(
}
(weightsArray, -sum + labelMean / labelStd, weightsArray.length)
}
-
+
private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
private val gradientSumArray = Array.ofDim[Double](dim)
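
Editor's note: the Scaladoc above lists the supported penalties; the following illustrative sketch (not from this patch) shows how the mixing works in practice: `regParam` scales the overall penalty, while `elasticNetParam` interpolates between L2 (0.0) and L1 (1.0).

```scala
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame

// `training`: a DataFrame with "label" and "features" columns.
def fitElasticNet(training: DataFrame): Unit = {
  val lr = new LinearRegression()
    .setMaxIter(100)          // maximum number of optimizer iterations
    .setRegParam(0.3)         // overall regularization strength
    .setElasticNetParam(0.8)  // 0.0 = ridge (L2), 1.0 = lasso (L1), in between = elastic net
  val model = lr.fit(training)
  println(s"weights: ${model.weights} intercept: ${model.intercept}")
}
```
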
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 82437aa8de294..49a1f7ce8c995 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.regression
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,12 +31,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression.
* It supports both continuous and categorical features.
*/
-@AlphaComponent
+@Experimental
final class RandomForestRegressor(override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestParams with TreeRegressorParams {
@@ -89,6 +88,7 @@ final class RandomForestRegressor(override val uid: String)
}
}
+@Experimental
object RandomForestRegressor {
/** Accessor for supported impurity settings: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -99,13 +99,12 @@ object RandomForestRegressor {
}
/**
- * :: AlphaComponent ::
- *
+ * :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
*/
-@AlphaComponent
+@Experimental
final class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel])
@@ -153,7 +152,7 @@ private[ml] object RandomForestRegressionModel {
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
val newTrees = oldModel.trees.map { tree =>
- // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ // parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
new RandomForestRegressionModel(parent.uid, newTrees)
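
Editor's note: and the forest counterpart. The hypothetical sketch below only touches the forest-specific knobs; everything else is shared with the tree parameters shown earlier.

```scala
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.sql.DataFrame

// `training`: a DataFrame with "label" and "features" columns.
def fitForest(training: DataFrame): RandomForestRegressionModel = {
  val model = new RandomForestRegressor()
    .setNumTrees(50)
    .setFeatureSubsetStrategy("auto")  // features considered per split; "auto" picks a default based on numTrees
    .setMaxDepth(5)
    .fit(training)
  println(s"trained ${model.trees.length} trees")
  model
}
```
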
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d2dec0c76cb12..4242154be14ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -17,14 +17,16 @@
package org.apache.spark.ml.tree
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
Node => OldNode, Predict => OldPredict}
-
/**
+ * :: DeveloperApi ::
* Decision tree node interface.
*/
+@DeveloperApi
sealed abstract class Node extends Serializable {
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
@@ -89,10 +91,12 @@ private[ml] object Node {
}
/**
+ * :: DeveloperApi ::
* Decision tree leaf node.
* @param prediction Prediction this node makes
* @param impurity Impurity measure at this node (for training data)
*/
+@DeveloperApi
final class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double) extends Node {
@@ -118,6 +122,7 @@ final class LeafNode private[ml] (
}
/**
+ * :: DeveloperApi ::
* Internal Decision Tree node.
* @param prediction Prediction this node would make if it were a leaf node
* @param impurity Impurity measure at this node (for training data)
@@ -127,6 +132,7 @@ final class LeafNode private[ml] (
* @param rightChild Right-hand child node
* @param split Information about the test used to split to the left or right child.
*/
+@DeveloperApi
final class InternalNode private[ml] (
override val prediction: Double,
override val impurity: Double,
@@ -153,9 +159,9 @@ final class InternalNode private[ml] (
override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
val prefix: String = " " * indentFactor
- prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+ prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" +
leftChild.subtreeToString(indentFactor + 1) +
- prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+ prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" +
rightChild.subtreeToString(indentFactor + 1)
}
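
Editor's note: since `Node` is now a public `@DeveloperApi` sealed hierarchy, callers can walk a fitted tree directly. A small illustrative helper (not part of the patch):

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

// Depth-first walk over the sealed Node hierarchy, e.g. numLeaves(model.rootNode).
def numLeaves(node: Node): Int = node match {
  case _: LeafNode     => 1
  case n: InternalNode => numLeaves(n.leftChild) + numLeaves(n.rightChild)
}
```
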
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 90f1d052764d3..7acdeeee72d23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -17,15 +17,18 @@
package org.apache.spark.ml.tree
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
import org.apache.spark.mllib.tree.model.{Split => OldSplit}
/**
+ * :: DeveloperApi ::
* Interface for a "Split," which specifies a test made at a decision tree node
* to choose the left or right path.
*/
+@DeveloperApi
sealed trait Split extends Serializable {
/** Index of feature which this split tests */
@@ -52,12 +55,14 @@ private[tree] object Split {
}
/**
+ * :: DeveloperApi ::
* Split which tests a categorical feature.
* @param featureIndex Index of the feature to test
* @param _leftCategories If the feature value is in this set of categories, then the split goes
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
+@DeveloperApi
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
_leftCategories: Array[Double],
@@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] (
}
/**
+ * :: DeveloperApi ::
* Split which tests a continuous feature.
* @param featureIndex Index of the feature to test
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
+@DeveloperApi
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {
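
Editor's note: likewise for `Split`: with the trait sealed and public, a match over the two concrete cases is exhaustive. A hedged sketch, assuming the public `featureIndex`, `threshold`, and `leftCategories` accessors described above:

```scala
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split}

// Human-readable description of the test a node applies, in the spirit of subtreeToString.
def describe(split: Split): String = split match {
  case c: ContinuousSplit =>
    s"feature ${c.featureIndex} <= ${c.threshold}"
  case c: CategoricalSplit =>
    s"feature ${c.featureIndex} in {${c.leftCategories.mkString(", ")}}"
}
```
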
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 816fcedf2efb3..a0c5238d966bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.tree
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
@@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
/**
- * :: DeveloperApi ::
* Parameters for Decision Tree-based algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-@DeveloperApi
private[ml] trait DecisionTreeParams extends PredictorParams {
/**
@@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams {
}
/**
- * :: DeveloperApi ::
* Parameters for Decision Tree-based ensemble algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-@DeveloperApi
private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
/**
@@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
}
/**
- * :: DeveloperApi ::
* Parameters for Random Forest algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-@DeveloperApi
private[ml] trait RandomForestParams extends TreeEnsembleParams {
/**
@@ -377,12 +370,10 @@ private[ml] object RandomForestParams {
}
/**
- * :: DeveloperApi ::
* Parameters for Gradient-Boosted Tree algorithms.
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-@DeveloperApi
private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e21ff94a20f54..cb29392e8bc63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning
import com.github.fommil.netlib.F2jBLAS
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
@@ -79,10 +79,10 @@ private[ml] trait CrossValidatorParams extends Params {
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* K-fold cross validation.
*/
-@AlphaComponent
+@Experimental
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
with CrossValidatorParams with Logging {
@@ -102,12 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def validateParams(paramMap: ParamMap): Unit = {
- getEstimatorParamMaps.foreach { eMap =>
- getEstimator.validateParams(eMap ++ paramMap)
- }
- }
-
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
@@ -141,26 +135,35 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+ copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
$(estimator).transformSchema(schema)
}
+
+ override def validateParams(): Unit = {
+ super.validateParams()
+ val est = $(estimator)
+ for (paramMap <- $(estimatorParamMaps)) {
+ est.copy(paramMap).validateParams()
+ }
+ }
}
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Model from k-fold cross validation.
*/
-@AlphaComponent
+@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
- val bestModel: Model[_])
+ val bestModel: Model[_],
+ val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def validateParams(paramMap: ParamMap): Unit = {
- bestModel.validateParams(paramMap)
+ override def validateParams(): Unit = {
+ bestModel.validateParams()
}
override def transform(dataset: DataFrame): DataFrame = {
@@ -171,4 +174,12 @@ class CrossValidatorModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
bestModel.transformSchema(schema)
}
+
+ override def copy(extra: ParamMap): CrossValidatorModel = {
+ val copied = new CrossValidatorModel(
+ uid,
+ bestModel.copy(extra).asInstanceOf[Model[_]],
+ avgMetrics.clone())
+ copyValues(copied, extra)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
index dafe73d82c00a..98a8f0330ca45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
@@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning
import scala.annotation.varargs
import scala.collection.mutable
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param._
/**
- * :: AlphaComponent ::
+ * :: Experimental ::
* Builder for a param grid used in grid search-based model selection.
*/
-@AlphaComponent
+@Experimental
class ParamGridBuilder {
private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
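
Editor's note: `CrossValidator` and `ParamGridBuilder` are typically used together. A sketch of the intended flow, including the `avgMetrics` array added to `CrossValidatorModel` in this diff; the estimator, evaluator, and grid values are arbitrary choices for illustration (both classes are generic over any Estimator/Evaluator pair).

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.DataFrame

// `training`: a DataFrame with a binary "label" column and a "features" column.
def selectModel(training: DataFrame): CrossValidatorModel = {
  val lr = new LogisticRegression()
  val grid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.01, 0.1, 1.0))
    .addGrid(lr.maxIter, Array(10, 50))
    .build()
  val cv = new CrossValidator()
    .setEstimator(lr)
    .setEvaluator(new BinaryClassificationEvaluator())
    .setEstimatorParamMaps(grid)
    .setNumFolds(3)
  val cvModel = cv.fit(training)
  // avgMetrics(i) is the metric for grid(i) averaged over the 3 folds;
  // bestModel is refit on the full training set with the winning ParamMap.
  println(cvModel.avgMetrics.mkString(", "))
  cvModel
}
```
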
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2fa54df6fc2b2..8f66bc808a007 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -43,7 +43,8 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.ChiSqTestResult
-import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
+import org.apache.spark.mllib.stat.{
+ KernelDensity, MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.loss.Losses
@@ -392,14 +393,14 @@ private[python] class PythonMLLibAPI extends Serializable {
data: JavaRDD[Vector],
wt: Vector,
mu: Array[Object],
- si: Array[Object]): RDD[Vector] = {
+ si: Array[Object]): RDD[Vector] = {
val weight = wt.toArray
val mean = mu.map(_.asInstanceOf[DenseVector])
val sigma = si.map(_.asInstanceOf[DenseMatrix])
val gaussians = Array.tabulate(weight.length){
i => new MultivariateGaussian(mean(i), sigma(i))
- }
+ }
val model = new GaussianMixtureModel(weight, gaussians)
model.predictSoft(data).map(Vectors.dense)
}
@@ -428,7 +429,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) als.setSeed(seed)
- val model = als.run(ratingsJRDD.rdd)
+ val model = als.run(ratingsJRDD.rdd)
new MatrixFactorizationModelWrapper(model)
}
@@ -459,7 +460,7 @@ private[python] class PythonMLLibAPI extends Serializable {
if (seed != null) als.setSeed(seed)
- val model = als.run(ratingsJRDD.rdd)
+ val model = als.run(ratingsJRDD.rdd)
new MatrixFactorizationModelWrapper(model)
}
@@ -494,7 +495,7 @@ private[python] class PythonMLLibAPI extends Serializable {
def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
new Normalizer(p).transform(rdd)
}
-
+
/**
* Java stub for StandardScaler.fit(). This stub returns a
* handle to the Java object instead of the content of the Java object.
@@ -945,6 +946,15 @@ private[python] class PythonMLLibAPI extends Serializable {
r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
}
+ /**
+ * Java stub for the estimate method of KernelDensity
+ */
+ def estimateKernelDensity(
+ sample: JavaRDD[Double],
+ bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
+ return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
+ points.asScala.toArray)
+ }
}
@@ -1242,7 +1252,7 @@ private[spark] object SerDe extends Serializable {
}
/* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
- def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
rdd.map(x => Array(x._1, x._2))
}
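
Editor's note: the new Python stub above wraps the Scala `KernelDensity` builder; the direct Scala call looks roughly like this (the bandwidth value is arbitrary).

```scala
import org.apache.spark.mllib.stat.KernelDensity
import org.apache.spark.rdd.RDD

// Gaussian kernel density estimate of `sample`, evaluated at each of `points`.
def densityAt(sample: RDD[Double], points: Array[Double]): Array[Double] = {
  new KernelDensity()
    .setSample(sample)
    .setBandwidth(3.0)
    .estimate(points)
}
```
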
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index c88410ac0ff43..fc509d2ba1470 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,11 +37,11 @@ import org.apache.spark.util.Utils
* independent Gaussian distributions with associated "mixing" weights
* specifying each's contribution to the composite.
*
- * Given a set of sample points, this class will maximize the log-likelihood
- * for a mixture of k Gaussians, iterating until the log-likelihood changes by
+ * Given a set of sample points, this class will maximize the log-likelihood
+ * for a mixture of k Gaussians, iterating until the log-likelihood changes by
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
- * to find a global optimum.
+ * to find a global optimum.
*
* Note: For high-dimensional data (with many features), this algorithm may perform poorly.
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
@@ -53,24 +54,24 @@ import org.apache.spark.util.Utils
*/
@Experimental
class GaussianMixture private (
- private var k: Int,
- private var convergenceTol: Double,
+ private var k: Int,
+ private var convergenceTol: Double,
private var maxIterations: Int,
private var seed: Long) extends Serializable {
-
+
/**
* Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
* maxIterations: 100, seed: random}.
*/
def this() = this(2, 0.01, 100, Utils.random.nextLong())
-
+
// number of samples per cluster to use when initializing Gaussians
private val nSamples = 5
-
- // an initializing GMM can be provided rather than using the
+
+ // an initializing GMM can be provided rather than using the
// default random starting point
private var initialModel: Option[GaussianMixtureModel] = None
-
+
/** Set the initial GMM starting point, bypassing the random initialization.
* You must call setK() prior to calling this method, and the condition
* (model.k == this.k) must be met; failure will result in an IllegalArgumentException
@@ -83,37 +84,37 @@ class GaussianMixture private (
}
this
}
-
+
/** Return the user supplied initial GMM, if supplied */
def getInitialModel: Option[GaussianMixtureModel] = initialModel
-
+
/** Set the number of Gaussians in the mixture model. Default: 2 */
def setK(k: Int): this.type = {
this.k = k
this
}
-
+
/** Return the number of Gaussians in the mixture model */
def getK: Int = k
-
+
/** Set the maximum number of iterations to run. Default: 100 */
def setMaxIterations(maxIterations: Int): this.type = {
this.maxIterations = maxIterations
this
}
-
+
/** Return the maximum number of iterations to run */
def getMaxIterations: Int = maxIterations
-
+
/**
- * Set the largest change in log-likelihood at which convergence is
+ * Set the largest change in log-likelihood at which convergence is
* considered to have occurred.
*/
def setConvergenceTol(convergenceTol: Double): this.type = {
this.convergenceTol = convergenceTol
this
}
-
+
/**
* Return the largest change in log-likelihood at which convergence is
* considered to have occurred.
@@ -132,41 +133,41 @@ class GaussianMixture private (
/** Perform expectation maximization */
def run(data: RDD[Vector]): GaussianMixtureModel = {
val sc = data.sparkContext
-
+
// we will operate on the data as breeze data
val breezeData = data.map(_.toBreeze).cache()
-
+
// Get length of the input vectors
val d = breezeData.first().length
-
+
// Determine initial weights and corresponding Gaussians.
// If the user supplied an initial GMM, we use those values, otherwise
// we start with uniform weights, a random mean from the data, and
// diagonal covariance matrices using component variances
- // derived from the samples
+ // derived from the samples
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weights, gmm.gaussians)
-
+
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
- (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
+ (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
- new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
+ new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
}
}
-
- var llh = Double.MinValue // current log-likelihood
+
+ var llh = Double.MinValue // current log-likelihood
var llhp = 0.0 // previous log-likelihood
-
+
var iter = 0
while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) {
// create and broadcast curried cluster contribution function
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
-
+
// aggregate the cluster contribution for all sample points
val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
-
+
// Create new distributions based on the partial assignments
// (often referred to as the "M" step in literature)
val sumWeights = sums.weights.sum
@@ -179,22 +180,25 @@ class GaussianMixture private (
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
i = i + 1
}
-
+
llhp = llh // current becomes previous
llh = sums.logLikelihood // this is the freshly computed log-likelihood
iter += 1
- }
-
+ }
+
new GaussianMixtureModel(weights, gaussians)
}
-
+
+ /** Java-friendly version of [[run()]] */
+ def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)
+
/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
- v / x.length.toDouble
+ v / x.length.toDouble
}
-
+
/**
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
@@ -210,14 +214,14 @@ class GaussianMixture private (
// companion class to provide zero constructor for ExpectationSum
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
- new ExpectationSum(0.0, Array.fill(k)(0.0),
- Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+ new ExpectationSum(0.0, Array.fill(k)(0.0),
+ Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
}
-
+
// compute cluster contributions for each input point
// (U, T) => U for aggregation
def add(
- weights: Array[Double],
+ weights: Array[Double],
dists: Array[MultivariateGaussian])
(sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
@@ -235,7 +239,7 @@ private object ExpectationSum {
i = i + 1
}
sums
- }
+ }
}
// Aggregation class for partial expectation results
@@ -244,9 +248,9 @@ private class ExpectationSum(
val weights: Array[Double],
val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
-
+
val k = weights.length
-
+
def +=(x: ExpectationSum): ExpectationSum = {
var i = 0
while (i < k) {
@@ -257,5 +261,5 @@ private class ExpectationSum(
}
logLikelihood += x.logLikelihood
this
- }
+ }
}
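
Editor's note: a compact, illustrative run of the EM driver whose formatting is cleaned up above; the data and settings are toy values.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

object GaussianMixtureSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("gmm-sketch").setMaster("local[2]"))
    val data = sc.parallelize(Seq(
      Vectors.dense(-5.0, -4.8), Vectors.dense(-5.2, -5.1),
      Vectors.dense(4.9, 5.0), Vectors.dense(5.1, 4.7)))

    val gmm = new GaussianMixture()
      .setK(2)
      .setConvergenceTol(0.01)   // stop once the log-likelihood improves by less than this
      .setMaxIterations(100)
      .run(data)

    gmm.weights.zip(gmm.gaussians).foreach { case (w, g) =>
      println(s"weight=$w mu=${g.mu}\nsigma=\n${g.sigma}")
    }
    sc.stop()
  }
}
```
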
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 86353aed81156..cb807c8038101 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
@@ -34,10 +35,10 @@ import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
*
- * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
- * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
- * the respective mean and covariance for each Gaussian distribution i=1..k.
- *
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
+ * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are
+ * the respective mean and covariance for each Gaussian distribution i=1..k.
+ *
* @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is
* the weight for Gaussian i, and weights.sum == 1
* @param gaussians Array of MultivariateGaussian where gaussians(i) represents
@@ -45,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row}
*/
@Experimental
class GaussianMixtureModel(
- val weights: Array[Double],
- val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
-
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
+
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
override protected def formatVersion = "1.0"
@@ -64,20 +65,24 @@ class GaussianMixtureModel(
val responsibilityMatrix = predictSoft(points)
responsibilityMatrix.map(r => r.indexOf(r.max))
}
-
+
+ /** Java-friendly version of [[predict()]] */
+ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+ predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
/**
* Given the input vectors, return the membership value of each vector
- * to all mixture components.
+ * to all mixture components.
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val bcDists = sc.broadcast(gaussians)
val bcWeights = sc.broadcast(weights)
- points.map { x =>
+ points.map { x =>
computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
-
+
/**
* Compute the partial assignments for each vector
*/
@@ -89,7 +94,7 @@ class GaussianMixtureModel(
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
}
- val pSum = p.sum
+ val pSum = p.sum
for (i <- 0 until k) {
p(i) /= pSum
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6cf26445f20a0..974b26924dfb8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
import org.apache.spark.rdd.RDD
@@ -345,6 +346,11 @@ class DistributedLDAModel private (
}
}
+ /** Java-friendly version of [[topicDistributions]] */
+ def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
+ JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
+ }
+
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 6fa2fe053c6a4..8e5154b902d1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
* Default: 1024, following the original Online LDA paper.
*/
def setTau0(tau0: Double): this.type = {
- require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0")
+ require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0")
this.tau0 = tau0
this
}
@@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
override private[clustering] def initialize(
docs: RDD[(Long, Vector)],
- lda: LDA): OnlineLDAOptimizer = {
+ lda: LDA): OnlineLDAOptimizer = {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size
@@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
* uses digamma which is accurate but expensive.
*/
private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = {
- val rowSum = sum(alpha(breeze.linalg.*, ::))
+ val rowSum = sum(alpha(breeze.linalg.*, ::))
val digAlpha = digamma(alpha)
val digRowSum = digamma(rowSum)
val result = digAlpha(::, breeze.linalg.*) - digRowSum
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 1ed01c9d8ba0b..e7a243f854e33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] (
import org.apache.spark.mllib.clustering.PowerIterationClustering._
/** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
- * initMode: "random"}.
+ * initMode: "random"}.
*/
def this() = this(k = 2, maxIterations = 100, initMode = "random")
@@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging {
/**
* Generates random vertex properties (v0) to start power iteration.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing a random vector
* with unit 1-norm
@@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging {
* Generates the degree vector as the vertex properties (v0) to start power iteration.
* It is not exactly the node degrees but just the normalized sum similarities. Call it
* as degree vector because it is used in the PIC paper.
- *
+ *
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing the degree vector
*/
@@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging {
val v0 = g.vertices.mapValues(_ / sum)
GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
}
-
+
/**
* Runs power iteration.
* @param g input graph with edges representing the normalized affinity matrix (W) and vertices
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 812014a041719..d9b34cec64894 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -21,8 +21,10 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -178,7 +180,7 @@ class StreamingKMeans(
/** Set the decay factor directly (for forgetful algorithms). */
def setDecayFactor(a: Double): this.type = {
- this.decayFactor = decayFactor
+ this.decayFactor = a
this
}
@@ -234,6 +236,9 @@ class StreamingKMeans(
}
}
+ /** Java-friendly version of `trainOn`. */
+ def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)
+
/**
* Use the clustering model to make predictions on batches of data from a DStream.
*
@@ -245,6 +250,11 @@ class StreamingKMeans(
data.map(model.predict)
}
+ /** Java-friendly version of `predictOn`. */
+ def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
+ JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
+ }
+
/**
* Use the model to make predictions on the values of a DStream and carry over its keys.
*
@@ -257,6 +267,14 @@ class StreamingKMeans(
data.mapValues(model.predict)
}
+ /** Java-friendly version of `predictOnValues`. */
+ def predictOnValues[K](
+ data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
+ implicit val tag = fakeClassTag[K]
+ JavaPairDStream.fromPairDStream(
+ predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
+ }
+
/** Check whether cluster centers have been initialized. */
private[this] def assertInitialized(): Unit = {
if (model.clusterCenters == null) {
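
Editor's note: the Java-friendly wrappers added above mirror the existing DStream API; a sketch of the Scala side, assuming the input streams come from an existing StreamingContext.

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.streaming.dstream.DStream

// `trainingData`: a stream of points; `testData`: keyed points to label per batch.
def attachModel(trainingData: DStream[Vector], testData: DStream[(String, Vector)]): Unit = {
  val model = new StreamingKMeans()
    .setK(3)
    .setDecayFactor(1.0)       // 1.0 = remember all batches; smaller values forget old data faster
    .setRandomCenters(2, 0.0)  // dimension 2, initial center weight 0.0
  model.trainOn(trainingData)              // update cluster centers on every training batch
  model.predictOnValues(testData).print()  // keep the keys, emit cluster indices
}
```
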
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 9cc2d0ffcab7d..5f8c1dea237b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -108,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf
* (ordered by statistic value descending)
*/
@Experimental
-class ChiSqSelector (val numTopFeatures: Int) {
+class ChiSqSelector (val numTopFeatures: Int) extends Serializable {
/**
* Returns a ChiSquared feature selector.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
index b0985baf9b278..d67fe6c3ee4f8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala
@@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._
* Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
* provided "weight" vector. In other words, it scales each column of the dataset by a scalar
* multiplier.
- * @param scalingVector The values used to scale the reference vector's individual components.
+ * @param scalingVec The values used to scale the reference vector's individual components.
*/
@Experimental
-class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
+class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer {
/**
* Does the hadamard product transformation.
@@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
* @return transformed vector.
*/
override def transform(vector: Vector): Vector = {
- require(vector.size == scalingVector.size,
- s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}")
+ require(vector.size == scalingVec.size,
+ s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}")
vector match {
case dv: DenseVector =>
val values: Array[Double] = dv.values.clone()
- val dim = scalingVector.size
+ val dim = scalingVec.size
var i = 0
while (i < dim) {
- values(i) *= scalingVector(i)
+ values(i) *= scalingVec(i)
i += 1
}
Vectors.dense(values)
@@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer {
val dim = values.length
var i = 0
while (i < dim) {
- values(i) *= scalingVector(indices(i))
+ values(i) *= scalingVec(indices(i))
i += 1
}
Vectors.sparse(size, indices, values)
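
Editor's note: the rename from `scalingVector` to `scalingVec` above is API-visible; a short spark-shell style illustration of what the transformer computes:

```scala
import org.apache.spark.mllib.feature.ElementwiseProduct
import org.apache.spark.mllib.linalg.Vectors

val transformer = new ElementwiseProduct(Vectors.dense(0.0, 1.0, 2.0))
val scaled = transformer.transform(Vectors.dense(3.0, 3.0, 3.0))
// scaled == Vectors.dense(0.0, 3.0, 6.0): each coordinate is multiplied by scalingVec(i)
```
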
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index a89eea0e21be2..efbfeb4059f5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -144,7 +144,7 @@ private object IDF {
* Since arrays are initialized to 0 by default,
* we just omit changing those entries.
*/
- if(df(j) >= minDocFreq) {
+ if (df(j) >= minDocFreq) {
inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
}
j += 1
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 6ae6917eae595..c73b8f258060d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -90,7 +90,7 @@ class StandardScalerModel (
@DeveloperApi
def setWithMean(withMean: Boolean): this.type = {
- require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null")
+ require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null")
this.withMean = withMean
this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 9106b73dfcd76..51546d41c36a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -42,32 +42,32 @@ import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.sql.{SQLContext, Row}
/**
- * Entry in vocabulary
+ * Entry in vocabulary
*/
private case class VocabWord(
var word: String,
var cn: Int,
var point: Array[Int],
var code: Array[Int],
- var codeLen:Int
+ var codeLen: Int
)
/**
* :: Experimental ::
* Word2Vec creates vector representation of words in a text corpus.
* The algorithm first constructs a vocabulary from the corpus
- * and then learns vector representation of words in the vocabulary.
- * The vector representation can be used as features in
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
* natural language processing and machine learning algorithms.
- *
- * We used skip-gram model in our implementation and hierarchical softmax
+ *
+ * We used skip-gram model in our implementation and hierarchical softmax
* method to train the model. The variable names in the implementation
* matches the original C implementation.
*
- * For original C implementation, see https://code.google.com/p/word2vec/
- * For research papers, see
+ * For original C implementation, see https://code.google.com/p/word2vec/
+ * For research papers, see
* Efficient Estimation of Word Representations in Vector Space
- * and
+ * and
* Distributed Representations of Words and Phrases and their Compositionality.
*/
@Experimental
@@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging {
private var numIterations = 1
private var seed = Utils.random.nextLong()
private var minCount = 5
-
+
/**
* Sets vector size (default: 100).
*/
@@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging {
this
}
- /**
- * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
+ /**
+ * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5).
*/
def setMinCount(minCount: Int): this.type = {
this.minCount = minCount
this
}
-
+
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
@@ -150,13 +150,13 @@ class Word2Vec extends Serializable with Logging {
.map(x => VocabWord(
x._1,
x._2,
- new Array[Int](MAX_CODE_LENGTH),
- new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
+ new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
-
+
vocabSize = vocab.length
require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
"the setting of minCount, which could be large enough to remove all your words in sentences.")
@@ -198,8 +198,8 @@ class Word2Vec extends Serializable with Logging {
}
var pos1 = vocabSize - 1
var pos2 = vocabSize
-
- var min1i = 0
+
+ var min1i = 0
var min2i = 0
a = 0
@@ -268,15 +268,15 @@ class Word2Vec extends Serializable with Logging {
val words = dataset.flatMap(x => x)
learnVocab(words)
-
+
createBinaryTree()
-
+
val sc = dataset.context
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
-
+
val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
new Iterator[Array[Int]] {
def hasNext: Boolean = iter.hasNext
@@ -297,7 +297,7 @@ class Word2Vec extends Serializable with Logging {
}
}
}
-
+
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
@@ -402,7 +402,7 @@ class Word2Vec extends Serializable with Logging {
}
}
newSentences.unpersist()
-
+
val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
@@ -469,7 +469,7 @@ class Word2VecModel private[mllib] (
val norm1 = blas.snrm2(n, v1, 1)
val norm2 = blas.snrm2(n, v2, 1)
if (norm1 == 0 || norm2 == 0) return 0.0
- blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
+ blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
}
override protected def formatVersion = "1.0"
@@ -480,7 +480,7 @@ class Word2VecModel private[mllib] (
/**
* Transforms a word to its vector representation
- * @param word a word
+ * @param word a word
* @return vector representation of word
*/
def transform(word: String): Vector = {
@@ -495,18 +495,18 @@ class Word2VecModel private[mllib] (
/**
* Find synonyms of a word
* @param word a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
- findSynonyms(vector,num)
+ findSynonyms(vector, num)
}
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
- * @param num number of synonyms to find
+ * @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
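
Editor's note: a hedged usage sketch for the Word2Vec trainer and the `findSynonyms` call whose docs are tidied above. `corpus` is any RDD of tokenized sentences, and the query token is assumed to occur in the vocabulary.

```scala
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// `corpus`: tokenized sentences, e.g. sc.textFile(...).map(_.split(" ").toSeq)
def trainAndQuery(corpus: RDD[Seq[String]]): Unit = {
  val model = new Word2Vec()
    .setVectorSize(100)
    .setMinCount(5)      // drop tokens that appear fewer than 5 times
    .setNumIterations(1)
    .fit(corpus)
  model.findSynonyms("spark", 5).foreach { case (word, cosine) =>
    println(f"$word%-15s $cosine%.4f")
  }
}
```
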
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ec38529cf8fae..557119f7b1cd1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
_nativeBLAS
}
-
+
/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
@@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging {
j += 1
}
i += 1
- }
+ }
}
private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
@@ -505,7 +505,7 @@ private[spark] object BLAS extends Serializable with Logging {
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
-
+
/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A and `SparseVector` x.
@@ -557,7 +557,7 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
}
-
+
/**
* y := alpha * A * x + beta * y
* For `SparseMatrix` A and `SparseVector` x.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 866936aa4f118..ae3ba3099c878 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition {
require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE,
s"k = $k and/or n = $n are too large to compute an eigendecomposition")
-
+
var ido = new intW(0)
var info = new intW(0)
var resid = new Array[Double](n)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index f6bcdf83cd337..2ffa497a99d93 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
}
override def serialize(obj: Any): Row = {
- val row = new GenericMutableRow(4)
obj match {
case SparseVector(size, indices, values) =>
+ val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
row.update(2, indices.toSeq)
row.update(3, values.toSeq)
+ row
case DenseVector(values) =>
+ val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
row.update(3, values.toSeq)
+ row
+ // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+ // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+ // TODO: deserialize may get called twice. See SPARK-7186.
+ case row: Row =>
+ row
}
- row
}
override def deserialize(datum: Any): Vector = {
datum match {
- // TODO: something wrong with UDT serialization
- case v: Vector =>
- v
case row: Row =>
require(row.length == 4,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
@@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val values = row.getAs[Iterable[Double]](3).toArray
new DenseVector(values)
}
+ // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+ // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+ // TODO: deserialize may get called twice. See SPARK-7186.
+ case v: Vector =>
+ v
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 9a89a6f3a515f..1626da9c3d2ee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -219,7 +219,7 @@ class RowMatrix(
val computeMode = mode match {
case "auto" =>
- if(k > 5000) {
+ if (k > 5000) {
logWarning(s"computing svd with k=$k and n=$n, please check necessity")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 4b7d0589c973b..06e45e10c5bf4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -179,7 +179,7 @@ object GradientDescent extends Logging {
* if it's L2 updater; for L1 updater, the same logic is followed.
*/
var regVal = updater.compute(
- weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
+ weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
val bcWeights = data.context.broadcast(weights)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
index 34b447584e521..622b53a252ac5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala
@@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
* PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel
*/
private[mllib] class BinaryClassificationPMMLModelExport(
- model : GeneralizedLinearModel,
+ model : GeneralizedLinearModel,
description : String,
normalizationMethod : RegressionNormalizationMethodType,
- threshold: Double)
+ threshold: Double)
extends PMMLModelExport {
populateBinaryClassificationPMML()
@@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport(
.withUsageType(FieldUsageType.ACTIVE))
regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
}
-
+
// add target field
val targetField = FieldName.create("target")
dataDictionary
@@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport(
miningSchema
.withMiningFields(new MiningField(targetField)
.withUsageType(FieldUsageType.TARGET))
-
+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
-
+
pmml.setDataDictionary(dataDictionary)
pmml.withModels(regressionModel)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
index ebdeae50bb32f..c5fdecd3ca17f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala
@@ -25,7 +25,7 @@ import scala.beans.BeanProperty
import org.dmg.pmml.{Application, Header, PMML, Timestamp}
private[mllib] trait PMMLModelExport {
-
+
/**
* Holder of the exported model in PMML format
*/
@@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport {
val pmml: PMML = new PMML
setHeader(pmml)
-
+
private def setHeader(pmml: PMML): Unit = {
val version = getClass.getPackage.getImplementationVersion
val app = new Application().withName("Apache Spark MLlib").withVersion(version)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
index c16e83d6a067d..29bd689e1185a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala
@@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.RidgeRegressionModel
private[mllib] object PMMLModelExportFactory {
-
+
/**
- * Factory object to help creating the necessary PMMLModelExport implementation
+ * Factory object to help creating the necessary PMMLModelExport implementation
* taking as input the machine learning model (for example KMeansModel).
*/
def createPMMLModelExport(model: Any): PMMLModelExport = {
@@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory {
new GeneralizedLinearPMMLModelExport(lasso, "lasso regression")
case svm: SVMModel =>
new BinaryClassificationPMMLModelExport(
- svm, "linear SVM", RegressionNormalizationMethodType.NONE,
+ svm, "linear SVM", RegressionNormalizationMethodType.NONE,
svm.getThreshold.getOrElse(0.0))
case logistic: LogisticRegressionModel =>
if (logistic.numClasses == 2) {
@@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory {
"PMML Export not supported for model: " + model.getClass.getName)
}
}
-
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 8341bb86afd71..174d5e0f6c9f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -52,7 +52,7 @@ object RandomRDDs {
numPartitions: Int = 0,
seed: Long = Utils.random.nextLong()): RDD[Double] = {
val uniform = new UniformGenerator()
- randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed)
+ randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed)
}
/**
@@ -234,7 +234,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution
- * @param scale scale parameter (> 0) for the gamma distribution
+ * @param scale scale parameter (> 0) for the gamma distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -293,7 +293,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param mean mean for the log normal distribution
- * @param std standard deviation for the log normal distribution
+ * @param std standard deviation for the log normal distribution
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
@@ -671,7 +671,7 @@ object RandomRDDs {
*
* @param sc SparkContext used to create the RDD.
* @param shape shape parameter (> 0) for the gamma distribution.
- * @param scale scale parameter (> 0) for the gamma distribution.
+ * @param scale scale parameter (> 0) for the gamma distribution.
* @param numRows Number of Vectors in the RDD.
* @param numCols Number of elements in each Vector.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`)
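
The scaladoc touched above documents the gamma and log-normal generators; a brief usage sketch of those entry points, assuming a SparkContext `sc` is in scope:

    import org.apache.spark.mllib.random.RandomRDDs

    // 1000 draws each, spread over 4 partitions (parameter names as in the docs above).
    val uniform = RandomRDDs.uniformRDD(sc, size = 1000L, numPartitions = 4, seed = 42L)
    val gamma = RandomRDDs.gammaRDD(sc, shape = 2.0, scale = 1.5, size = 1000L, numPartitions = 4)
    val logNormal = RandomRDDs.logNormalRDD(sc, mean = 0.0, std = 1.0, size = 1000L, numPartitions = 4)
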
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index dddefe1944e9d..93290e6508529 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -175,7 +175,7 @@ class ALS private (
/**
* :: DeveloperApi ::
* Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default
- * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
+ * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g.
* `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement,
* at the cost of speed.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 26be30ff9d6fd..6709bd79bc820 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -195,11 +195,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
val initialWeights = {
if (numOfLinearPredictor == 1) {
- Vectors.dense(new Array[Double](numFeatures))
+ Vectors.zeros(numFeatures)
} else if (addIntercept) {
- Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor))
+ Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
} else {
- Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor))
+ Vectors.zeros(numFeatures * numOfLinearPredictor)
}
}
run(input, initialWeights)
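
The replacement above is purely cosmetic: `Vectors.zeros(n)` builds the same all-zero dense vector as `Vectors.dense(new Array[Double](n))`, just more readably. A one-line check of the equivalence:

    import org.apache.spark.mllib.linalg.Vectors

    val n = 3
    // Both yield a dense vector [0.0, 0.0, 0.0]; zeros simply states the intent.
    assert(Vectors.zeros(n).toArray.sameElements(Vectors.dense(new Array[Double](n)).toArray))
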
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 3ea63dd8c0acd..f3b46c75c05f3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -170,15 +170,15 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
case class Data(boundary: Double, prediction: Double)
def save(
- sc: SparkContext,
- path: String,
- boundaries: Array[Double],
- predictions: Array[Double],
+ sc: SparkContext,
+ path: String,
+ boundaries: Array[Double],
+ predictions: Array[Double],
isotonic: Boolean): Unit = {
val sqlContext = new SQLContext(sc)
val metadata = compact(render(
- ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
("isotonic" -> isotonic)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
@@ -203,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
implicit val formats = DefaultFormats
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
- val isotonic = (metadata \ "isotonic").extract[Boolean]
+ val isotonic = (metadata \ "isotonic").extract[Boolean]
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index cea8f3f47307b..aee51bf22d8d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,15 +83,7 @@ abstract class StreamingLinearAlgorithm[
throw new IllegalArgumentException("Model must be initialized before starting training.")
}
data.foreachRDD { (rdd, time) =>
- val initialWeights =
- model match {
- case Some(m) =>
- m.weights
- case None =>
- val numFeatures = rdd.first().features.size
- Vectors.dense(numFeatures)
- }
- model = Some(algorithm.run(rdd, initialWeights))
+ model = Some(algorithm.run(rdd, model.get.weights))
logInfo("Model updated at time %s".format(time.toString))
val display = model.get.weights.size match {
case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index a49153bf73c0d..235e043c7754b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -79,7 +79,7 @@ class StreamingLinearRegressionWithSGD private[mllib] (
this
}
- /** Set the initial weights. Default: [0.0, 0.0]. */
+ /** Set the initial weights. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
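
With the zero-weight fallback removed from StreamingLinearAlgorithm above, the model must be seeded explicitly before training, which is why the "[0.0, 0.0]" default was dropped from this doc. A hedged sketch of the expected call pattern; `trainingStream` and `testStream` are assumed to exist (a DStream[LabeledPoint] and a DStream[Vector] respectively):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

    val numFeatures = 3
    val model = new StreamingLinearRegressionWithSGD()
      .setStepSize(0.1)
      .setNumIterations(50)
      .setInitialWeights(Vectors.zeros(numFeatures))   // required before trainOn

    model.trainOn(trainingStream)
    model.predictOn(testStream).print()
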
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
index a6bfe26e1e4f5..58a50f9c19f14 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
@@ -93,7 +93,7 @@ class KernelDensity extends Serializable {
x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i))
i += 1
}
- (x._1, n)
+ (x._1, x._2 + 1)
},
(x, y) => {
blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
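
The fix above makes the second accumulator slot count records (`x._2 + 1`) instead of being overwritten with the number of evaluation points `n`. The same seqOp/combOp counting pattern in isolation, as a small sketch (assumes a SparkContext `sc`):

    // (sum, count) accumulator: seqOp folds in one record, combOp merges per-partition results.
    val data = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val (sum, count) = data.aggregate((0.0, 0L))(
      (acc, v) => (acc._1 + v, acc._2 + 1),
      (a, b) => (a._1 + b._1, a._2 + b._2))
    assert(sum == 10.0 && count == 4L)
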
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 0b1755613aac4..d321cc554c1cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -70,7 +70,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- val localCurrMean= currMean
+ val localCurrMean = currMean
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index b3fad0c52d655..900007ec6bc74 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -80,6 +81,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)
+ /** Java-friendly version of [[corr()]] */
+ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
+ corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])
+
/**
* Compute the correlation for the input RDDs using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
@@ -96,6 +101,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)
+ /** Java-friendly version of [[corr()]] */
+ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
+ corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)
+
/**
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
* expected distribution.
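
The new overloads above only unbox the JavaRDDs and delegate to the existing Scala methods. For reference, the Scala-side calls they forward to, assuming a SparkContext `sc`:

    import org.apache.spark.mllib.stat.Statistics

    val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val y = sc.parallelize(Seq(1.1, 2.2, 3.1, 4.3))
    val pearson = Statistics.corr(x, y)               // "pearson" is the default method
    val spearman = Statistics.corr(x, y, "spearman")
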
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index cd6add9d60b0d..cf51b24ff777f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils
* the event that the covariance matrix is singular, the density will be computed in a
* reduced dimensional subspace under which the distribution is supported.
* (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]])
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
@DeveloperApi
class MultivariateGaussian (
- val mu: Vector,
+ val mu: Vector,
val sigma: Matrix) extends Serializable {
require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
-
+
private val breezeMu = mu.toBreeze.toDenseVector
-
+
/**
* private[mllib] constructor
- *
+ *
* @param mu The mean vector of the distribution
* @param sigma The covariance matrix of the distribution
*/
private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = {
this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma))
}
-
+
/**
* Compute distribution dependent constants:
* rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t
- * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
+ * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
-
+
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
}
-
+
/** Returns density of this multivariate Gaussian at given point, x */
private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
-
+
/** Returns the log-density of this multivariate Gaussian at given point, x */
private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
}
-
+
/**
* Calculate distribution dependent components used for the density function:
* pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
* where k is length of the mean vector.
- *
- * We here compute distribution-fixed parts
+ *
+ * We here compute distribution-fixed parts
* log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
* and
* D^(-1/2)^ * U, where sigma = U * D * U.t
- *
+ *
* Both the determinant and the inverse can be computed from the singular value decomposition
* of sigma. Noting that covariance matrices are always symmetric and positive semi-definite,
* we can use the eigendecomposition. We also do not compute the inverse directly; noting
- * that
- *
+ * that
+ *
* sigma = U * D * U.t
- * inv(Sigma) = U * inv(D) * U.t
+ * inv(Sigma) = U * inv(D) * U.t
* = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U)
- *
+ *
* and thus
- *
+ *
* -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^
- *
- * To guard against singular covariance matrices, this method computes both the
+ *
+ * To guard against singular covariance matrices, this method computes both the
* pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered
* to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and
* relation to the maximum singular value (same tolerance used by, e.g., Octave).
*/
private def calculateCovarianceConstants: (DBM[Double], Double) = {
val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t
-
+
// For numerical stability, values are considered to be non-zero only if they exceed tol.
// This prevents any inverted value from exceeding (eps * n * max(d))^-1
val tol = MLUtils.EPSILON * max(d) * d.length
-
+
try {
// log(pseudo-determinant) is sum of the logs of all non-zero singular values
val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum
-
- // calculate the root-pseudo-inverse of the diagonal matrix of singular values
+
+ // calculate the root-pseudo-inverse of the diagonal matrix of singular values
// by inverting the square root of all non-zero values
val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray))
-
+
(pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma))
} catch {
case uex: UnsupportedOperationException =>
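
The scaladoc above (whitespace-only changes here) derives the density constants from the eigendecomposition of sigma, falling back to the pseudo-determinant and pseudo-inverse for singular covariances. As a quick sanity check of the public API rather than those internals: for a 2-D standard normal the density at the mean is 1/(2*pi). A small sketch:

    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

    val mu = Vectors.dense(0.0, 0.0)
    val sigma = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))   // identity covariance
    val gaussian = new MultivariateGaussian(mu, sigma)
    assert(math.abs(gaussian.pdf(mu) - 1.0 / (2.0 * math.Pi)) < 1e-9)
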
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index e597fce2babd1..23c8d7c7c8075 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging {
* Pearson's independence test on the input contingency matrix.
* TODO: optimize for SparseMatrix when it becomes supported.
*/
- def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = {
+ def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = {
val method = methodFromString(methodName)
val numRows = counts.numRows
val numCols = counts.numCols
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index dfe3a0b6913ef..cecd1fed896d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging {
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
- categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
+ categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).run(input)
@@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging {
*/
private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
+ rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
val predict = calculatePredict(parentNodeAgg)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 1f779584dcffd..a835f96d5d0e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
+ case Regression =>
+ GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput,
- remappedInput, boostingStrategy, validate=false)
+ GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoostedTrees.boost(
- input, validationInput, boostingStrategy, validate=true)
+ case Regression =>
+ GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(
@@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
val remappedValidationInput = validationInput.map(
x => new LabeledPoint((x.label * 2) - 1, x.features))
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate=true)
+ validate = true)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging {
logInfo(s"$timer")
if (persistedInput) input.unpersist()
-
+
if (validate) {
new GradientBoostedTreesModel(
boostingStrategy.treeStrategy.algo,
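
The reformatted calls above preserve the existing logic: for classification, binary {0, 1} labels are remapped to {-1, +1} so boosting can proceed as regression. The remapping in isolation:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // 0.0 -> -1.0 and 1.0 -> +1.0; features are untouched.
    val remap = (x: LabeledPoint) => new LabeledPoint((x.label * 2) - 1, x.features)
    assert(remap(LabeledPoint(0.0, Vectors.dense(1.0))).label == -1.0)
    assert(remap(LabeledPoint(1.0, Vectors.dense(2.0))).label == 1.0)
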
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index b347c450c1aa8..069959976a188 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
try {
nodeIdCache.get.deleteAllCheckpoints()
} catch {
- case e:IOException =>
+ case e: IOException =>
logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
}
}
@@ -474,7 +474,7 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 431a839817eac..a6d1398fc267b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -83,7 +83,7 @@ class Node (
def predict(features: Vector) : Double = {
if (isLeaf) {
predict.predict
- } else{
+ } else {
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
@@ -151,9 +151,9 @@ class Node (
s"(feature ${split.feature} > ${split.threshold})"
}
case Categorical => if (left) {
- s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})"
+ s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})"
} else {
- s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})"
+ s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})"
}
}
}
@@ -161,9 +161,9 @@ class Node (
if (isLeaf) {
prefix + s"Predict: ${predict.predict}\n"
} else {
- prefix + s"If ${splitToString(split.get, left=true)}\n" +
+ prefix + s"If ${splitToString(split.get, left = true)}\n" +
leftNode.get.subtreeToString(indentFactor + 1) +
- prefix + s"Else ${splitToString(split.get, left=false)}\n" +
+ prefix + s"Else ${splitToString(split.get, left = false)}\n" +
rightNode.get.subtreeToString(indentFactor + 1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
index 0c5b4f9d04a74..bd73a866c8a82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
@@ -82,8 +82,7 @@ object MFDataGenerator {
BLAS.gemm(z, A, B, 1.0, fullData)
val df = rank * (m + n - rank)
- val sampSize = scala.math.min(scala.math.round(trainSampFact * df),
- scala.math.round(.99 * m * n)).toInt
+ val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt
val rand = new Random()
val mn = m * n
val shuffled = rand.shuffle((0 until mn).toList)
@@ -102,8 +101,8 @@ object MFDataGenerator {
// optionally generate testing data
if (test) {
- val testSampSize = scala.math
- .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt
+ val testSampSize = math.min(
+ math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt
val testOmega = shuffled.slice(sampSize, sampSize + testSampSize)
val testOrdered = testOmega.sortWith(_ < _).toArray
val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered)
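
The condensed expression above computes the same training sample size as before: the degrees of freedom scaled by trainSampFact, capped at 99% of the matrix entries. A worked example with hypothetical dimensions:

    // m = 100 rows, n = 100 cols, rank = 10, trainSampFact = 1.0:
    // df = 10 * (100 + 100 - 10) = 1900, cap = round(0.99 * 10000) = 9900, so sampSize = 1900.
    val (m, n, rank, trainSampFact) = (100, 100, 10, 1.0)
    val df = rank * (m + n - rank)
    val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt
    assert(sampSize == 1900)
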
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 681f4c618d302..52d6468a72af7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, "indices should be one-based and in ascending order" )
+ previous = current
+ i += 1
+ }
+
(label, indices.toArray, values.toArray)
}
@@ -265,7 +277,7 @@ object MLUtils {
}
Vectors.fromBreeze(vector1)
}
-
+
/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
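
The loop added above rejects sparse rows whose feature indices are not strictly increasing. A hedged standalone sketch of that check (`checkIndices` is an illustrative helper, not part of MLUtils):

    def checkIndices(indices: Array[Int]): Unit = {
      var previous = -1
      var i = 0
      while (i < indices.length) {
        val current = indices(i)
        require(current > previous, "indices should be one-based and in ascending order")
        previous = current
        i += 1
      }
    }

    checkIndices(Array(0, 2, 6))      // strictly increasing: passes
    // checkIndices(Array(2, 0, 6))   // out of order: throws IllegalArgumentException
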
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
new file mode 100644
index 0000000000000..d5bd230a957a1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaBucketizerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void bucketizerTest() {
+ double[] splits = {-0.5, 0.0, 0.5};
+
+ JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(-0.5),
+ RowFactory.create(-0.3),
+ RowFactory.create(0.0),
+ RowFactory.create(0.2)
+ ));
+ StructType schema = new StructType(new StructField[] {
+ new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
+ });
+ DataFrame dataset = jsql.createDataFrame(data, schema);
+
+ Bucketizer bucketizer = new Bucketizer()
+ .setInputCol("feature")
+ .setOutputCol("result")
+ .setSplits(splits);
+
+ Row[] result = bucketizer.transform(dataset).select("result").collect();
+
+ for (Row r : result) {
+ double index = r.getDouble(0);
+ Assert.assertTrue((index >= 0) && (index <= 1));
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index da2218056307e..599e9cfd23ad4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -55,9 +55,9 @@ public void tearDown() {
@Test
public void hashingTF() {
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
- RowFactory.create(0, "Hi I heard about Spark"),
- RowFactory.create(0, "I wish Java could use case classes"),
- RowFactory.create(1, "Logistic regression models are neat")
+ RowFactory.create(0.0, "Hi I heard about Spark"),
+ RowFactory.create(0.0, "I wish Java could use case classes"),
+ RowFactory.create(1.0, "Logistic regression models are neat")
));
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
new file mode 100644
index 0000000000000..35b18c5308f61
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaStringIndexerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ sqlContext = null;
+ }
+
+ @Test
+ public void testStringIndexer() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("label", StringType, false)
+ });
+ JavaRDD<Row> rdd = jsc.parallelize(
+ Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+
+ StringIndexer indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex");
+ DataFrame output = indexer.fit(dataset).transform(dataset);
+
+ Assert.assertArrayEquals(
+ new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
+ output.orderBy("id").select("id", "labelIndex").collect());
+ }
+
+ /** An alias for RowFactory.create. */
+ private Row c(Object... values) {
+ return RowFactory.create(values);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
new file mode 100644
index 0000000000000..b7c564caad3bd
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.Arrays;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+public class JavaVectorAssemblerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext sqlContext;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
+ sqlContext = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void testVectorAssembler() {
+ StructType schema = createStructType(new StructField[] {
+ createStructField("id", IntegerType, false),
+ createStructField("x", DoubleType, false),
+ createStructField("y", new VectorUDT(), false),
+ createStructField("name", StringType, false),
+ createStructField("z", new VectorUDT(), false),
+ createStructField("n", LongType, false)
+ });
+ Row row = RowFactory.create(
+ 0, 0.0, Vectors.dense(1.0, 2.0), "a",
+ Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
+ JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
+ DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ VectorAssembler assembler = new VectorAssembler()
+ .setInputCols(new String[] {"x", "y", "z", "n"})
+ .setOutputCol("features");
+ DataFrame output = assembler.transform(dataset);
+ Assert.assertEquals(
+ Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+ output.select("features").first().getAs(0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index e7df10dfa63ac..9890155e9f865 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -50,6 +50,7 @@ public void testParams() {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
+ Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 947ae3a2ce06f..ff5929235ac2c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -51,7 +51,8 @@ public String uid() {
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
public JavaTestParams setMyIntParam(int value) {
- set(myIntParam_, value); return this;
+ set(myIntParam_, value);
+ return this;
}
private DoubleParam myDoubleParam_;
@@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) {
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
public JavaTestParams setMyDoubleParam(double value) {
- set(myDoubleParam_, value); return this;
+ set(myDoubleParam_, value);
+ return this;
}
private Param<String> myStringParam_;
@@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) {
public String getMyStringParam() { return getOrDefault(myStringParam_); }
public JavaTestParams setMyStringParam(String value) {
- set(myStringParam_, value); return this;
+ set(myStringParam_, value);
+ return this;
+ }
+
+ private DoubleArrayParam myDoubleArrayParam_;
+ public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
+
+ public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+
+ public JavaTestParams setMyDoubleArrayParam(double[] value) {
+ set(myDoubleArrayParam_, value);
+ return this;
}
private void init() {
@@ -79,8 +92,14 @@ private void init() {
List<String> validStrings = Lists.newArrayList("a", "b");
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
- setDefault(myIntParam_, 1);
- setDefault(myDoubleParam_, 0.5);
+ myDoubleArrayParam_ =
+ new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
+
+ setDefault(myIntParam(), 1);
+ setDefault(myIntParam().w(1));
+ setDefault(myDoubleParam(), 0.5);
setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
+ setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
+ setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 67c262d0f9d8d..928301523fba9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.util
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class IdentifiableSuite extends FunSuite {
+class IdentifiableSuite extends SparkFunSuite {
import IdentifiableSuite.Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
similarity index 95%
rename from mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
rename to mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index 640d2ec55e4e7..55787f8606d48 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.classification;
+package org.apache.spark.mllib.classification;
import java.io.Serializable;
import java.util.List;
@@ -28,7 +28,6 @@
import org.junit.Test;
import org.apache.spark.SparkConf;
-import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
new file mode 100644
index 0000000000000..467a7a69e8f30
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaGaussianMixtureSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGaussianMixture");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runGaussianMixture() {
+ List<Vector> points = Lists.newArrayList(
+ Vectors.dense(1.0, 2.0, 6.0),
+ Vectors.dense(1.0, 3.0, 0.0),
+ Vectors.dense(1.0, 4.0, 6.0)
+ );
+
+ JavaRDD<Vector> data = sc.parallelize(points, 2);
+ GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
+ .run(data);
+ assertEquals(model.gaussians().length, 2);
+ JavaRDD<Integer> predictions = model.predict(data);
+ predictions.first();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index 96c2da169961f..581c033f08ebe 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -107,6 +107,10 @@ public void distributedLDAModel() {
// Check: log probabilities
assert(model.logLikelihood() < 0.0);
assert(model.logPrior() < 0.0);
+
+ // Check: topic distributions
+ JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
+ assertEquals(topicDistributions.count(), corpus.count());
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
new file mode 100644
index 0000000000000..3b0e879eec77f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.apache.spark.streaming.JavaTestUtils.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+public class JavaStreamingKMeansSuite implements Serializable {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("test")
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+ ssc = new JavaStreamingContext(conf, new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void javaAPI() {
+ List<Vector> trainingBatch = Lists.newArrayList(
+ Vectors.dense(1.0),
+ Vectors.dense(0.0));
+ JavaDStream<Vector> training =
+ attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2);
+ List<Tuple2<Integer, Vector>> testBatch = Lists.newArrayList(
+ new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
+ new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+ JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
+ attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2));
+ StreamingKMeans skmeans = new StreamingKMeans()
+ .setK(1)
+ .setDecayFactor(1.0)
+ .setInitialCenters(new Vector[]{Vectors.dense(1.0)}, new double[]{0.0});
+ skmeans.trainOn(training);
+ JavaPairDStream<Integer, Integer> prediction = skmeans.predictOnValues(test);
+ attachTestOutputStream(prediction.count());
+ runStreams(ssc, 2, 2);
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
new file mode 100644
index 0000000000000..62f7f26b7c98f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaStatisticsSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaStatistics");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void testCorr() {
+ JavaRDD<Double> x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0));
+ JavaRDD<Double> y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3));
+
+ Double corr1 = Statistics.corr(x, y);
+ Double corr2 = Statistics.corr(x, y, "pearson");
+ // Check default method
+ assertEquals(corr1, corr2);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 2b04a3034782e..29394fefcbc43 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,15 +17,17 @@
package org.apache.spark.ml
+import scala.collection.JavaConverters._
+
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
-class PipelineSuite extends FunSuite {
+class PipelineSuite extends SparkFunSuite {
abstract class MyModel extends Model[MyModel]
@@ -81,4 +83,19 @@ class PipelineSuite extends FunSuite {
pipeline.fit(dataset)
}
}
+
+ test("pipeline model constructors") {
+ val transform0 = mock[Transformer]
+ val model1 = mock[MyModel]
+
+ val stages = Array(transform0, model1)
+ val pipelineModel0 = new PipelineModel("pipeline0", stages)
+ assert(pipelineModel0.uid === "pipeline0")
+ assert(pipelineModel0.stages === stages)
+
+ val stagesAsList = stages.toList.asJava
+ val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList)
+ assert(pipelineModel1.uid === "pipeline1")
+ assert(pipelineModel1.stages === stages)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 17ddd335deb6d..512cffb1acb66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class AttributeGroupSuite extends FunSuite {
+class AttributeGroupSuite extends SparkFunSuite {
test("attribute group") {
val attrs = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index ec9b717e41ce8..72b575d022547 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.attribute
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class AttributeSuite extends FunSuite {
+class AttributeSuite extends SparkFunSuite {
test("default numeric attribute") {
val attr: NumericAttribute = NumericAttribute.defaultAttr
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 3fdc66be8a314..ae40b0b8ff854 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeClassifierSuite.compareAPIs
@@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -266,7 +265,7 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index ea86867f1161a..1302da3c373ff 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTClassifierSuite.compareAPIs
@@ -128,7 +127,7 @@ private object GBTClassifierSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9f77d5f3efc55..a755cac3ea76e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var binaryDataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 770b56890fa45..1d04ccb509057 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -30,7 +29,7 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@transient var rdd: RDD[LabeledPoint] = _
@@ -94,6 +93,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index cdbbacab8e0e3..eee9355a67be3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.classification
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestClassifierSuite.compareAPIs
@@ -158,7 +157,7 @@ private object RandomForestClassifierSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 3ea7aad5274f2..36a1ac6b7996d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression Evaluator: default params") {
/**
@@ -39,7 +38,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
val dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
-
+
/**
* Using the following R code to load the data, train the model and evaluate metrics.
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 8f6c6b39dc93b..7953bd0417191 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Double] = _
@@ -48,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext {
test("Binarize continuous features with setter") {
val threshold: Double = 0.2
- val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
+ val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0)
val dataFrame: DataFrame = sqlContext.createDataFrame(
data.zip(thresholdBinarized)).toDF("feature", "expected")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 0391bd8427c2c..507a8a7db24c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -19,15 +19,13 @@ package org.apache.spark.ml.feature
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
+class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Bucket continuous features, without -inf,inf") {
// Check a set of valid feature values.
@@ -110,7 +108,7 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-private object BucketizerSuite extends FunSuite {
+private object BucketizerSuite extends SparkFunSuite {
/** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
require(feature >= splits.head)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 2e4beb0bfff63..7b2d70e644005 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -26,7 +25,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
val hashingTF = new HashingTF
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index f85e85471617a..d83772e8be755 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 9d09f24709e23..9f03470b7f328 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 056b9eda86bba..2e5036a844562 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.col
-
-class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -36,15 +36,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
indexer.transform(df)
}
- test("OneHotEncoder includeFirst = true") {
+ test("OneHotEncoder dropLast = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setInputCol("labelIndex")
.setOutputCol("labelVec")
+ .setDropLast(false)
val encoded = encoder.transform(transformed)
val output = encoded.select("id", "labelVec").map { r =>
- val vec = r.get(1).asInstanceOf[Vector]
+ val vec = r.getAs[Vector](1)
(r.getInt(0), vec(0), vec(1), vec(2))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
@@ -53,22 +54,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
assert(output === expected)
}
- test("OneHotEncoder includeFirst = false") {
+ test("OneHotEncoder dropLast = true") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
- .setIncludeFirst(false)
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)
val output = encoded.select("id", "labelVec").map { r =>
- val vec = r.get(1).asInstanceOf[Vector]
+ val vec = r.getAs[Vector](1)
(r.getInt(0), vec(0), vec(1))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
- val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
- (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
+ val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
+ (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
assert(output === expected)
}
+ test("input column with ML attribute") {
+ val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+ val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
+ .select(col("size").as("size", attr.toMetadata()))
+ val encoder = new OneHotEncoder()
+ .setInputCol("size")
+ .setOutputCol("encoded")
+ val output = encoder.transform(df)
+ val group = AttributeGroup.fromStructField(output.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+ }
+
+ test("input column without ML attribute") {
+ val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index")
+ val encoder = new OneHotEncoder()
+ .setInputCol("index")
+ .setOutputCol("encoded")
+ val output = encoder.transform(df)
+ val group = AttributeGroup.fromStructField(output.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+ }
}
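A minimal sketch (outside the patch) of the dropLast semantics the renamed tests above exercise, assuming a `sqlContext` is in scope as in these suites: with three observed category indices and the default dropLast = true, the encoded vector has two slots and the last category maps to the all-zero vector.

    import org.apache.spark.ml.feature.OneHotEncoder
    val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("categoryIndex")
    val encoded = new OneHotEncoder()
      .setInputCol("categoryIndex")
      .setOutputCol("categoryVec")   // dropLast defaults to true
      .transform(df)
    // Logically: 0.0 -> [1.0, 0.0], 1.0 -> [0.0, 1.0], 2.0 -> [0.0, 0.0]
    encoded.show()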
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index aa230ca073d5b..feca866cd711d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,15 +17,15 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
import org.scalatest.exceptions.TestFailedException
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
+class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Polynomial expansion with default parameter") {
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 89c2fe45573aa..5f557e16e5150 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("StringIndexer") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -61,4 +60,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
+
+ test("StringIndexerModel should keep silent if the input column does not exist.") {
+ val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ val df = sqlContext.range(0L, 10L)
+ assert(indexerModel.transform(df).eq(df))
+ }
}
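A minimal sketch (outside the patch) of the usual fit/transform path that produces the StringIndexerModel constructed directly in the new test, assuming `sqlContext` is in scope; labels are indexed by descending frequency.

    import org.apache.spark.ml.feature.StringIndexer
    val df = sqlContext.createDataFrame(
      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))).toDF("id", "label")
    val indexed = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
      .transform(df)
    // Most frequent label first: "a" -> 0.0, "c" -> 1.0, "b" -> 2.0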
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index eabda089d0988..ac279cb3215c2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
+class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("RegexTokenizer") {
@@ -60,7 +59,7 @@ class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object RegexTokenizerSuite extends FunSuite {
+object RegexTokenizerSuite extends SparkFunSuite {
def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index d0cd62c5e4864..489abb5af7130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
@@ -61,4 +61,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
}
}
+
+ test("ML attributes") {
+ val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari")
+ val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0)
+ val user = new AttributeGroup("user", Array(
+ NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
+ NumericAttribute.defaultAttr.withName("salary")))
+ val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
+ val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+ .select(
+ col("browser").as("browser", browser.toMetadata()),
+ col("hour").as("hour", hour.toMetadata()),
+ col("count"), // "count" is an integer column without ML attribute
+ col("user").as("user", user.toMetadata()),
+ col("ad")) // "ad" is a vector column without ML attribute
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("browser", "hour", "count", "user", "ad"))
+ .setOutputCol("features")
+ val output = assembler.transform(df)
+ val schema = output.schema
+ val features = AttributeGroup.fromStructField(schema("features"))
+ assert(features.size === 7)
+ val browserOut = features.getAttr(0)
+ assert(browserOut === browser.withIndex(0).withName("browser"))
+ val hourOut = features.getAttr(1)
+ assert(hourOut === hour.withIndex(1).withName("hour"))
+ val countOut = features.getAttr(2)
+ assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2))
+ val userGenderOut = features.getAttr(3)
+ assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
+ val userSalaryOut = features.getAttr(4)
+ assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
+ assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
+ assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+ }
}
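A minimal sketch (outside the patch) of how ML attribute metadata is attached to and recovered from a column, which is what the new "ML attributes" test above relies on; `sqlContext` is assumed to be in scope.

    import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
    import org.apache.spark.sql.functions.col
    val sizeAttr = NominalAttribute.defaultAttr.withName("size").withValues("small", "medium", "large")
    val df = sqlContext.createDataFrame(Seq(0.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size")
      .select(col("size").as("size", sizeAttr.toMetadata()))
    // The attribute round-trips through the schema and can be read back by downstream stages.
    val restored = Attribute.fromStructField(df.schema("size"))
    assert(restored.isNominal)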
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index b11b029c6343e..06affc7305cf5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.{BeanInfo, BeanProperty}
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
import VectorIndexerSuite.FeatureData
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 43a09cc418703..94ebc3aebfa37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Word2Vec") {
val sqlContext = new SQLContext(sc)
@@ -35,9 +34,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val codes = Map(
- "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
- "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342),
- "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
+ "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
+ "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
+ "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
)
val expected = doc.map { sentence =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 1505ad872536b..778abcba22c10 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -19,8 +19,7 @@ package org.apache.spark.ml.impl
import scala.collection.JavaConverters._
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.tree._
@@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
-private[ml] object TreeTests extends FunSuite {
+private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label metadata.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index d270ad7613af1..96094d7a099aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.param
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
-class ParamsSuite extends FunSuite {
+class ParamsSuite extends SparkFunSuite {
test("param") {
val solver = new TestParams()
@@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite {
import solver.{maxIter, inputCol}
assert(maxIter.name === "maxIter")
- assert(maxIter.doc === "max number of iterations (>= 0)")
+ assert(maxIter.doc === "maximum number of iterations (>= 0)")
assert(maxIter.parent === uid)
assert(maxIter.toString === s"${uid}__maxIter")
assert(!maxIter.isValid(-1))
@@ -36,7 +36,7 @@ class ParamsSuite extends FunSuite {
solver.setMaxIter(5)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
assert(inputCol.toString === s"${uid}__inputCol")
@@ -120,7 +120,7 @@ class ParamsSuite extends FunSuite {
intercept[NoSuchElementException](solver.getInputCol)
assert(solver.explainParam(maxIter) ===
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
+ "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() ===
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
@@ -135,7 +135,7 @@ class ParamsSuite extends FunSuite {
intercept[IllegalArgumentException] {
solver.validateParams()
}
- solver.validateParams(ParamMap(inputCol -> "input"))
+ solver.copy(ParamMap(inputCol -> "input")).validateParams()
solver.setInputCol("input")
assert(solver.isSet(inputCol))
assert(solver.isDefined(inputCol))
@@ -202,7 +202,7 @@ class ParamsSuite extends FunSuite {
}
}
-object ParamsSuite extends FunSuite {
+object ParamsSuite extends SparkFunSuite {
/**
* Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index ca18fa1ad3c15..eb5408d3fee7c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.ml.param.shared
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.Params
-class SharedParamsSuite extends FunSuite {
+class SharedParamsSuite extends SparkFunSuite {
test("outputCol") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9a35555e52b90..2e5cfe7027eb6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.scalatest.FunSuite
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -36,7 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils
-class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
private var tempDir: File = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 1196a772dfdd4..33aa9d0d62343 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
@@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import DecisionTreeRegressorSuite.compareAPIs
@@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
// TODO: test("model save/load") SPARK-6725
}
-private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
/**
* Train 2 decision trees on the given dataset, one using the old API and one using the new API.
@@ -83,7 +82,7 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite {
val oldTree = OldDecisionTree.train(data, oldStrategy)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newTree = dt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
TreeTests.checkEqual(oldTreeAsNew, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 40e7e3273e965..98fb3d3f5f22c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import GBTRegressorSuite.compareAPIs
@@ -129,7 +128,7 @@ private object GBTRegressorSuite {
val oldModel = oldGBT.run(data)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldModelAsNew = GBTRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 50a78631fa6d6..732e2c42be144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 3efffbb763b78..b24ecaa57c89b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.ml.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestRegressorSuite.compareAPIs
@@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
*/
}
-private object RandomForestRegressorSuite extends FunSuite {
+private object RandomForestRegressorSuite extends SparkFunSuite {
/**
* Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -114,7 +113,7 @@ private object RandomForestRegressorSuite extends FunSuite {
data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
- // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ // Use parent from newTree since this is not checked anyway.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 05313d440fbf6..9b3619f0046ea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -17,15 +17,19 @@
package org.apache.spark.ml.tuning
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.StructType
-class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var dataset: DataFrame = _
@@ -52,5 +56,56 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
+ assert(cvModel.avgMetrics.length === lrParamMaps.length)
+ }
+
+ test("validateParams should check estimatorParamMaps") {
+ import CrossValidatorSuite._
+
+ val est = new MyEstimator("est")
+ val eval = new MyEvaluator
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(est.inputCol, Array("input1", "input2"))
+ .build()
+
+ val cv = new CrossValidator()
+ .setEstimator(est)
+ .setEstimatorParamMaps(paramMaps)
+ .setEvaluator(eval)
+
+ cv.validateParams() // This should pass.
+
+ val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+ cv.setEstimatorParamMaps(invalidParamMaps)
+ intercept[IllegalArgumentException] {
+ cv.validateParams()
+ }
+ }
+}
+
+object CrossValidatorSuite {
+
+ abstract class MyModel extends Model[MyModel]
+
+ class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+ override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+ override def fit(dataset: DataFrame): MyModel = {
+ throw new UnsupportedOperationException
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ throw new UnsupportedOperationException
+ }
+ }
+
+ class MyEvaluator extends Evaluator {
+
+ override def evaluate(dataset: DataFrame): Double = {
+ throw new UnsupportedOperationException
+ }
+
+ override val uid: String = "eval"
}
}
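A minimal sketch (outside the patch) of the estimator / param-grid / evaluator wiring whose consistency the new validateParams test checks, assuming a `dataset: DataFrame` with label and features columns is in scope, as in this suite.

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.001, 0.1))
      .addGrid(lr.maxIter, Array(10))
      .build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setNumFolds(3)
    val cvModel = cv.fit(dataset)
    // cvModel.avgMetrics has one entry per ParamMap in `grid`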
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
index 20aa100112bfe..810b70049ec15 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.{ParamMap, TestParams}
-class ParamGridBuilderSuite extends FunSuite {
+class ParamGridBuilderSuite extends SparkFunSuite {
val solver = new TestParams()
import solver.{inputCol, maxIter}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index a629dba8a426f..59944416d96a6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.api.python
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating
-class PythonMLLibAPISuite extends FunSuite {
+class PythonMLLibAPISuite extends SparkFunSuite {
SerDe.initialize()
@@ -84,7 +83,7 @@ class PythonMLLibAPISuite extends FunSuite {
val smt = new SparseMatrix(
3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
- isTransposed=true)
+ isTransposed = true)
val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
assert(smt.toArray === nsmt.toArray)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 966811a5a3263..e8f3d0c4db20a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.util.Random
import scala.util.control.Breaks._
-import org.scalatest.FunSuite
import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -119,7 +119,7 @@ object LogisticRegressionSuite {
}
// Preventing the overflow when we compute the probability
val maxMargin = margins.max
- if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin
+ if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
// Computing the probabilities for each class from the margins.
val norm = {
@@ -130,7 +130,7 @@ object LogisticRegressionSuite {
}
temp
}
- for (i <-0 until nClasses) probs(i) /= norm
+ for (i <- 0 until nClasses) probs(i) /= norm
// Compute the cumulative probability so we can generate a random number and assign a label.
for (i <- 1 until nClasses) probs(i) += probs(i - 1)
@@ -169,7 +169,7 @@ object LogisticRegressionSuite {
}
-class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
def validatePrediction(
predictions: Seq[Double],
input: Seq[LabeledPoint],
@@ -541,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}
-class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction using SGD optimizer") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index c111a78a55806..f7fc8730606af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,9 +21,8 @@ import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -86,7 +85,7 @@ object NaiveBayesSuite {
pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
}
-class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
import NaiveBayes.{Multinomial, Bernoulli}
@@ -163,7 +162,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
val theta = Array(
Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
- Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
+ Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
).map(_.map(math.log))
val testData = NaiveBayesSuite.generateNaiveBayesInput(
@@ -286,7 +285,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 6de098b383ba3..b1d78cba9e3dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -21,9 +21,8 @@ import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -46,7 +45,7 @@ object SVMSuite {
nPoints: Int,
seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
- val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
+ val weightsMat = new DoubleMatrix(1, weights.length, weights : _*)
val x = Array.fill[Array[Double]](nPoints)(
Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
val y = x.map { xi =>
@@ -62,7 +61,7 @@ object SVMSuite {
}
-class SVMSuite extends FunSuite with MLlibTestSparkContext {
+class SVMSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -91,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val model = svm.run(testRDD)
val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData, 2)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -117,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val B = -1.5
val C = 1.0
- val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -127,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val model = svm.run(testRDD)
- val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData, 2)
+ val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -145,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val B = -1.5
val C = 1.0
- val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
val initialB = -1.0
val initialC = -1.0
@@ -159,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val model = svm.run(testRDD, initialWeights)
- val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData,2)
+ val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -177,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
val B = -1.5
val C = 1.0
- val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
val testRDDInvalid = testRDD.map { lp =>
@@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index 5683b55e8500a..e98b61e13e21f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index f356ffa3e3a26..b218d72f1268a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
+class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
test("single cluster") {
val data = sc.parallelize(Array(
Vectors.dense(6.0, 9.0),
@@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
}
-
+
test("two clusters") {
val data = sc.parallelize(GaussianTestData.data)
@@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
-
+
val gmm = new GaussianMixture()
.setK(2)
.setInitialModel(initialGmm)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 0f2b26d462ad2..0dbbd7127444f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class KMeansSuite extends FunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
@@ -75,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
val center = Vectors.dense(1.0, 2.0, 3.0)
// Make sure code runs.
- var model = KMeans.train(data, k=2, maxIterations=1)
+ var model = KMeans.train(data, k = 2, maxIterations = 1)
assert(model.clusterCenters.size === 2)
}
@@ -87,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
2)
// Make sure code runs.
- var model = KMeans.train(data, k=3, maxIterations=1)
+ var model = KMeans.train(data, k = 3, maxIterations = 1)
assert(model.clusterCenters.size === 3)
}
@@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object KMeansSuite extends FunSuite {
+object KMeansSuite extends SparkFunSuite {
def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
val singlePoint = isSparse match {
case true =>
@@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite {
}
}
-class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
+class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index d5b7d96335744..406affa25539d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class LDASuite extends FunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
import LDASuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 6d6fe6fe46bab..19e65f1b53ab5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable
import scala.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
+class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.mllib.clustering.PowerIterationClustering._
@@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
-
+
val model2 = new PowerIterationClustering()
.setK(2)
.setInitializationMode("degree")
@@ -94,11 +92,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
*/
val similarities = Seq[(Long, Long, Double)](
(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0))
+ // scalastyle:off
val expected = Array(
Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0),
Array(1.0/2.0, 0.0, 1.0/2.0, 0.0),
Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0),
Array(1.0/2.0, 0.0, 1.0/2.0, 0.0))
+ // scalastyle:on
val w = normalize(sc.parallelize(similarities, 2))
w.edges.collect().foreach { case Edge(i, j, x) =>
assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14)
@@ -128,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
}
}
-object PowerIterationClusteringSuite extends FunSuite {
+object PowerIterationClusteringSuite extends SparkFunSuite {
def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
val assignments = sc.parallelize(
(0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index f90025d535e45..ac01622b8a089 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.clustering
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
-class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
@@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
assert(math.abs(c1) ~== 0.8 absTol 0.6)
}
+ test("SPARK-7946 setDecayFactor") {
+ val kMeans = new StreamingKMeans()
+ assert(kMeans.decayFactor === 1.0)
+ kMeans.setDecayFactor(2.0)
+ assert(kMeans.decayFactor === 2.0)
+ }
+
def StreamingKMeansDataGenerator(
numPoints: Int,
numBatches: Int,
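A minimal sketch (outside the patch) of the decay-factor setter covered by the SPARK-7946 test above: decayFactor = 1.0 weights all past batches equally, while smaller values discount older batches.

    import org.apache.spark.mllib.clustering.StreamingKMeans
    val kMeans = new StreamingKMeans()
      .setK(2)
      .setDecayFactor(0.5)              // forget old batches geometrically
      .setRandomCenters(3, 0.0, 42L)    // dim = 3, initial weight 0.0, fixed seed
    // kMeans.trainOn(trainingStream) would then update the centers per batch,
    // given a DStream[Vector] named `trainingStream` (hypothetical).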
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
index 79847633ff0dc..87ccc7eda44ea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
+class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext {
test("auc computation") {
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
val auc = 4.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index e0224f960cc43..99d52fabc5309 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index 2537dd62c92f2..f3b19aeb42f84 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Multilabel evaluation metrics") {
/*
* Documents true labels (5x class0, 3x class1, 4x class2):
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 609eed983ff4e..c0924a213a844 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Ranking metrics: map, ndcg") {
val predictionAndLabels = sc.parallelize(
Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 670b4c34e6095..9de2bdb6d7246 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.mllib.evaluation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
+class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("regression metrics") {
val predictionAndObservations = sc.parallelize(
- Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2)
+ Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
"explained variance regression score mismatch")
@@ -39,7 +38,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
test("regression metrics with complete fitting") {
val predictionAndObservations = sc.parallelize(
- Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2)
+ Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
"explained variance regression score mismatch")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 747f5914598ec..889727fb55823 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext {
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
/*
* Contingency tables
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
index f3a482abda873..ccbf8a91cdd37 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext {
+class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext {
test("elementwise (hadamard) product should properly apply vector to dense data set") {
val denseData = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index 0c4dfb7b97c7f..cf279c02334e9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("hashing tf on a single doc") {
val hashingTF = new HashingTF(1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 0a5cad7caf8e4..21163633051e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class IDFSuite extends FunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
test("idf") {
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
index 5c4af2b99e68b..34122d6ed2e95 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
import breeze.linalg.{norm => brzNorm}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val data = Array(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index 758af588f1c69..e57f49191378f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -17,13 +17,12 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class PCASuite extends FunSuite with MLlibTestSparkContext {
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
private val data = Array(
Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index 7f94564b2a3ae..6ab2fa6770123 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
import org.apache.spark.rdd.RDD
-class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
// When the input data is all constant, the variance is zero. The standardization against
// zero variance is not well-defined, but we decide to just set it into zero here.
@@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
}
withClue("model needs std and mean vectors to be equal size when both are provided") {
intercept[IllegalArgumentException] {
- val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0))
+ val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0))
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 98a98a7599bcb..b6818369208d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.feature
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
// TODO: add more tests
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index bd5b9cc3afa10..66ae3543ecc4e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -16,11 +16,10 @@
*/
package org.apache.spark.mllib.fpm
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
+class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
test("FP-Growth using String type") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
index 04017f67c311d..a56d7b3579213 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm
import scala.language.existentials
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class FPTreeSuite extends FunSuite with MLlibTestSparkContext {
+class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
test("add transaction") {
val tree = new FPTree[String]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index 699f009f0f2ec..d34888af2d73b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -17,18 +17,16 @@
package org.apache.spark.mllib.impl
-import org.scalatest.FunSuite
-
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
import PeriodicGraphCheckpointerSuite._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 64ecd12ea7ded..b0f3f71113c57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.linalg.BLAS._
-class BLASSuite extends FunSuite {
+class BLASSuite extends SparkFunSuite {
test("copy") {
val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0))
@@ -140,7 +139,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dA)
assert(dA ~== expected absTol 1e-15)
-
+
val dB =
new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
@@ -149,7 +148,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dB)
}
}
-
+
val dC =
new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
@@ -158,7 +157,7 @@ class BLASSuite extends FunSuite {
syr(alpha, x, dC)
}
}
-
+
val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
withClue("Size of vector must match the rank of matrix") {
@@ -256,13 +255,13 @@ class BLASSuite extends FunSuite {
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
-
+
val dA2 =
new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
val sA2 =
new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
true)
-
+
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
@@ -271,7 +270,7 @@ class BLASSuite extends FunSuite {
assert(sA.multiply(dx) ~== expected absTol 1e-15)
assert(dA.multiply(sx) ~== expected absTol 1e-15)
assert(sA.multiply(sx) ~== expected absTol 1e-15)
-
+
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
@@ -288,7 +287,7 @@ class BLASSuite extends FunSuite {
val y14 = y1.copy
val y15 = y1.copy
val y16 = y1.copy
-
+
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
@@ -296,42 +295,42 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA, dx, 2.0, y2)
gemv(1.0, dA, sx, 2.0, y3)
gemv(1.0, sA, sx, 2.0, y4)
-
+
gemv(1.0, dA2, dx, 2.0, y5)
gemv(1.0, sA2, dx, 2.0, y6)
gemv(1.0, dA2, sx, 2.0, y7)
gemv(1.0, sA2, sx, 2.0, y8)
-
+
gemv(2.0, dA, dx, 2.0, y9)
gemv(2.0, sA, dx, 2.0, y10)
gemv(2.0, dA, sx, 2.0, y11)
gemv(2.0, sA, sx, 2.0, y12)
-
+
gemv(2.0, dA2, dx, 2.0, y13)
gemv(2.0, sA2, dx, 2.0, y14)
gemv(2.0, dA2, sx, 2.0, y15)
gemv(2.0, sA2, sx, 2.0, y16)
-
+
assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
assert(y3 ~== expected2 absTol 1e-15)
assert(y4 ~== expected2 absTol 1e-15)
-
+
assert(y5 ~== expected2 absTol 1e-15)
assert(y6 ~== expected2 absTol 1e-15)
assert(y7 ~== expected2 absTol 1e-15)
assert(y8 ~== expected2 absTol 1e-15)
-
+
assert(y9 ~== expected3 absTol 1e-15)
assert(y10 ~== expected3 absTol 1e-15)
assert(y11 ~== expected3 absTol 1e-15)
assert(y12 ~== expected3 absTol 1e-15)
-
+
assert(y13 ~== expected3 absTol 1e-15)
assert(y14 ~== expected3 absTol 1e-15)
assert(y15 ~== expected3 absTol 1e-15)
assert(y16 ~== expected3 absTol 1e-15)
-
+
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
gemv(1.0, dA.transpose, dx, 2.0, y1)
@@ -346,12 +345,12 @@ class BLASSuite extends FunSuite {
gemv(1.0, sA.transpose, sx, 2.0, y1)
}
}
-
+
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
-
+
val dATT = dAT.transpose
val sATT = sAT.transpose
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
index 2031032373971..dc04258e41d27 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM}
-class BreezeMatrixConversionSuite extends FunSuite {
+import org.apache.spark.SparkFunSuite
+
+class BreezeMatrixConversionSuite extends SparkFunSuite {
test("dense matrix to breeze") {
val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
val breeze = mat.toBreeze.asInstanceOf[BDM[Double]]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
index 8abdac72902c6..3772c9235ad3a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.mllib.linalg
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import org.apache.spark.SparkFunSuite
+
/**
* Test Breeze vector conversions.
*/
-class BreezeVectorConversionSuite extends FunSuite {
+class BreezeVectorConversionSuite extends SparkFunSuite {
val arr = Array(0.1, 0.2, 0.3, 0.4)
val n = 20
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 86119ec38101e..8dbb70f5d1c4c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg
import java.util.Random
import org.mockito.Mockito.when
-import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class MatricesSuite extends FunSuite {
+class MatricesSuite extends SparkFunSuite {
test("dense matrix construction") {
val m = 3
val n = 2
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 24755e9ff46fc..c4ae0a16f7c04 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg
import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.util.TestingUtils._
-class VectorsSuite extends FunSuite {
+class VectorsSuite extends SparkFunSuite {
val arr = Array(0.1, 0.0, 0.3, 0.4)
val n = 4
@@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite {
val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze)
- // SparseVector vs. SparseVector
- assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
+ // SparseVector vs. SparseVector
+ assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. SparseVector
assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8)
// DenseVector vs. DenseVector
assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8)
- }
+ }
}
test("foreachActive") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index 949d1c9939570..93fe04c139b9a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed
import java.{util => ju}
import breeze.linalg.{DenseMatrix => BDM}
-import org.scalatest.FunSuite
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
@@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
val random = new ju.Random()
// This should generate a 4x4 grid of 1x2 blocks.
val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12)
+ // scalastyle:off
val expected0 = Array(
Array(0, 0, 4, 4, 8, 8, 12),
Array(1, 1, 5, 5, 9, 9, 13),
Array(2, 2, 6, 6, 10, 10, 14),
Array(3, 3, 7, 7, 11, 11, 15))
+ // scalastyle:on
for (i <- 0 until 4; j <- 0 until 7) {
assert(part0.getPartition((i, j)) === expected0(i)(j))
assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index 04b36a9ef9990..f3728cd036a3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.linalg.Vectors
-class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 5
val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 2ab53cc13db71..4a7b99a976f0a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.linalg.distributed
-import org.scalatest.FunSuite
-
import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
-class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 27bb19f472e1e..b6cb53d0c743e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed
import scala.util.Random
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd}
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
-class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
+class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
val m = 4
val n = 3
@@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
var mat: RowMatrix = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 86481c6e66200..a5a59e9fad5ae 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization
import scala.collection.JavaConversions._
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
@@ -42,7 +43,7 @@ object GradientDescentSuite {
offset: Double,
scale: Double,
nPoints: Int,
- seed: Int): Seq[LabeledPoint] = {
+ seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
@@ -61,7 +62,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
@@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc
}
}
-class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext {
+class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index c8f2adcf155a7..d07b9d5b89227 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.{FunSuite, Matchers}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
-class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
val nPoints = 10000
val A = 2.0
@@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
}
}
-class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small") {
val m = 10
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index 22855e4e8f247..d8f9b8c33963d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
-import org.scalatest.FunSuite
-
import org.jblas.{DoubleMatrix, SimpleBlas}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.TestingUtils._
-class NNLSSuite extends FunSuite {
+class NNLSSuite extends SparkFunSuite {
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = {
val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*)
@@ -68,12 +67,14 @@ class NNLSSuite extends FunSuite {
test("NNLS: nonnegativity constraint active") {
val n = 5
+ // scalastyle:off
val ata = new DoubleMatrix(Array(
Array( 4.377, -3.531, -1.306, -0.139, 3.418),
Array(-3.531, 4.344, 0.934, 0.305, -2.140),
Array(-1.306, 0.934, 2.644, -0.203, -0.170),
Array(-0.139, 0.305, -0.203, 5.883, 1.428),
Array( 3.418, -2.140, -0.170, 1.428, 4.684)))
+ // scalastyle:on
val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636))
val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
index 0b646cf1ce6c4..4c6e76e47419b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
import org.dmg.pmml.RegressionNormalizationMethodType
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.classification.SVMModel
import org.apache.spark.mllib.util.LinearDataGenerator
-class BinaryClassificationPMMLModelExportSuite extends FunSuite {
+class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite {
test("logistic regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
@@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure logistic regression has normalization method set to LOGIT
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT)
}
-
+
test("linear SVM PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
-
+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
-
+
// assert that the PMML format is as expected
assert(svmModelExport.isInstanceOf[PMMLModelExport])
val pmml = svmModelExport.getPmml
@@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite {
// ensure linear SVM has normalization method set to NONE
assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
index f9afbd888dfc5..1d32309481787 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.RegressionModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
+class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite {
test("linear regression PMML export") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
index b985d0446d7b0..b3f9750afa730 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.pmml.export
import org.dmg.pmml.ClusteringModel
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
-class KMeansPMMLModelExportSuite extends FunSuite {
+class KMeansPMMLModelExportSuite extends SparkFunSuite {
test("KMeansPMMLModelExport generate PMML format") {
val clusterCenters = Array(
@@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite {
val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
}
-
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
index f28a4ac8ad01f..af49450961750 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.pmml.export
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
import org.apache.spark.mllib.util.LinearDataGenerator
-class PMMLModelExportFactorySuite extends FunSuite {
+class PMMLModelExportFactorySuite extends SparkFunSuite {
test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {
val clusterCenters = Array(
@@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite {
test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport "
+ "when passing a LogisticRegressionModel or SVMModel") {
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
-
+
val logisticRegressionModel =
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
val logisticRegressionModelExport =
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
-
+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport])
}
-
+
test("PMMLModelExportFactory throw IllegalArgumentException "
+ "when passing a Multinomial Logistic Regression") {
/** 3 classes, 2 features */
val multiclassLogisticRegressionModel = new LogisticRegressionModel(
- weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
+ weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
numFeatures = 2, numClasses = 3)
-
+
intercept[IllegalArgumentException] {
PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index b792d819fdabb..a5ca1518f82f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.random
import scala.math
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.util.StatCounter
// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
-class RandomDataGeneratorSuite extends FunSuite {
+class RandomDataGeneratorSuite extends SparkFunSuite {
def apiChecks(gen: RandomDataGenerator[Double]) {
// resetting seed should generate the same sequence of random numbers
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index 63f2ea916d457..413db2000d6d7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.random
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
@@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter
*
* TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
*/
-class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
+class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
def testGeneratedRDD(rdd: RDD[Double],
expectedSize: Long,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
index 57216e8eb4a55..10f5a2be48f7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
-class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("topByKey") {
val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5,
1), (3, 5)), 2)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 6d6c0aa5be812..bc64172614830 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.rdd
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.rdd.RDDFunctions._
-class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("sliding") {
val data = 0 until 6
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index b3798940ddc38..05b87728d6fdb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -21,9 +21,9 @@ import scala.collection.JavaConversions._
import scala.math.abs
import scala.util.Random
-import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.storage.StorageLevel
@@ -84,7 +84,7 @@ object ALSSuite {
}
-class ALSSuite extends FunSuite with MLlibTestSparkContext {
+class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
test("rank-1 matrices") {
testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index 2c92866f3893d..2c8ed057a516a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.recommendation
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
-class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
+class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext {
val rank = 2
var userFeatures: RDD[(Int, Array[Double])] = _
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
index 3b38bdf5ef5eb..ea4f2865757c1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.regression
-import org.scalatest.{Matchers, FunSuite}
+import org.scalatest.Matchers
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
private def round(d: Double) = {
math.round(d * 100).toDouble / 100
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
index 110c44a7193fd..d8364a06de4da 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.mllib.regression
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-class LabeledPointSuite extends FunSuite {
+class LabeledPointSuite extends SparkFunSuite {
test("parse labeled points") {
val points = Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index c9f5dc069ef2e..08a152ffc7a23 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LassoSuite {
val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LassoSuite extends FunSuite with MLlibTestSparkContext {
+class LassoSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -67,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
- val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ val validationData = LinearDataGenerator
+ .generateLinearInput(A, Array[Double](B, C), nPoints, 17)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}
- val validationRDD = sc.parallelize(validationData, 2)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -110,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
- val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ val validationData = LinearDataGenerator
+ .generateLinearInput(A, Array[Double](B, C), nPoints, 17)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
}
- val validationRDD = sc.parallelize(validationData,2)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -141,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LassoClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 3781931c2f819..f88a1c33c9f7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression
import scala.util.Random
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -32,7 +31,7 @@ private object LinearRegressionSuite {
val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
@@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index d6c93cc0e49cd..7a781fee634c8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression
import scala.util.Random
import org.jblas.DoubleMatrix
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
MLlibTestSparkContext}
@@ -33,7 +33,7 @@ private object RidgeRegressionSuite {
val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
}
-class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
+class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = {
predictions.zip(input).map { case (prediction, expected) =>
@@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
}
-class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
test("task size should be small in both training and prediction") {
val m = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 26604dbe6c1ef..9a379406d5061 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.TestSuiteBase
-class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
+class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index d20a09b4b4925..c292ced75e870 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
SpearmanCorrelation}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
+class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext {
// test input data
val xData = Array(1.0, 0.0, -2.0)
@@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
val X = sc.parallelize(data)
val defaultMat = Statistics.corr(X)
val pearsonMat = Statistics.corr(X, "pearson")
+ // scalastyle:off
val expected = BDM(
(1.00000000, 0.05564149, Double.NaN, 0.4004714),
(0.05564149, 1.00000000, Double.NaN, 0.9135959),
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
- (0.40047142, 0.91359586, Double.NaN,1.0000000))
+ (0.40047142, 0.91359586, Double.NaN, 1.0000000))
+ // scalastyle:on
assert(matrixApproxEqual(defaultMat.toBreeze, expected))
assert(matrixApproxEqual(pearsonMat.toBreeze, expected))
}
@@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
test("corr(X) spearman") {
val X = sc.parallelize(data)
val spearmanMat = Statistics.corr(X, "spearman")
+ // scalastyle:off
val expected = BDM(
(1.0000000, 0.1054093, Double.NaN, 0.4000000),
(0.1054093, 1.0000000, Double.NaN, 0.9486833),
(Double.NaN, Double.NaN, 1.00000000, Double.NaN),
(0.4000000, 0.9486833, Double.NaN, 1.0000000))
+ // scalastyle:on
assert(matrixApproxEqual(spearmanMat.toBreeze, expected))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 15418e6035965..b084a5fb4313f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat
import java.util.Random
-import org.scalatest.FunSuite
-
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
+class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("chi squared pearson goodness of fit") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
index 14bb1cebf0b8f..5feccdf33681a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
@@ -18,19 +18,19 @@
package org.apache.spark.mllib.stat
import org.apache.commons.math3.distribution.NormalDistribution
-import org.scalatest.FunSuite
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
+class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
test("kernel density single sample") {
val rdd = sc.parallelize(Array(5.0))
val evaluationPoints = Array(5.0, 6.0)
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
val normal = new NormalDistribution(5.0, 3.0)
val acceptableErr = 1e-6
- assert(densities(0) - normal.density(5.0) < acceptableErr)
- assert(densities(0) - normal.density(6.0) < acceptableErr)
+ assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
+ assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
}
test("kernel density multiple samples") {
@@ -40,7 +40,9 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
val normal1 = new NormalDistribution(5.0, 3.0)
val normal2 = new NormalDistribution(10.0, 3.0)
val acceptableErr = 1e-6
- assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr)
- assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr)
+ assert(math.abs(
+ densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
+ assert(math.abs(
+ densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
}
}
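
The KernelDensitySuite hunk above changes more than formatting: the old assertions compared a signed difference against the tolerance, so an estimate far below the expected density still satisfied the check, and densities(1) was never inspected at all. A minimal standalone illustration of why the absolute-value form is the right comparison (plain Scala, no Spark dependencies; the numbers are invented for the example):

object AbsToleranceExample {
  def main(args: Array[String]): Unit = {
    val expected = 0.39894          // target density value
    val badEstimate = 0.12          // clearly wrong estimate
    val tol = 1e-6

    // Signed comparison: 0.12 - 0.39894 is negative, so this passes silently
    // even though the estimate is far off.
    assert(badEstimate - expected < tol)

    // Absolute comparison: this is the corrected form and would fail here,
    // which is the behaviour a regression test needs.
    // assert(math.abs(badEstimate - expected) < tol)
  }
}
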
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 23b0eec865de6..07efde4f5e6dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.stat
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateOnlineSummarizerSuite extends FunSuite {
+class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
test("basic error handing") {
val summarizer = new MultivariateOnlineSummarizer
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index fac2498e4dcb3..aa60deb665aeb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -17,49 +17,48 @@
package org.apache.spark.mllib.stat.distribution
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{ Vectors, Matrices }
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext {
+class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
test("univariate") {
val x1 = Vectors.dense(0.0)
val x2 = Vectors.dense(1.5)
-
+
val mu = Vectors.dense(0.0)
val sigma1 = Matrices.dense(1, 1, Array(1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(1, 1, Array(4.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5)
}
-
+
test("multivariate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
val dist1 = new MultivariateGaussian(mu, sigma1)
assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5)
assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5)
-
+
val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
val dist2 = new MultivariateGaussian(mu, sigma2)
assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5)
assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5)
}
-
+
test("multivariate degenerate") {
val x1 = Vectors.dense(0.0, 0.0)
val x2 = Vectors.dense(1.0, 1.0)
-
+
val mu = Vectors.dense(0.0, 0.0)
val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0))
val dist = new MultivariateGaussian(mu, sigma)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index ce983eb27fa35..356d957f15909 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
-class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
// Tests examining individual elements of training
@@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
}
-object DecisionTreeSuite extends FunSuite {
+object DecisionTreeSuite extends SparkFunSuite {
def validateClassifier(
model: DecisionTreeModel,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 55b0bac7d49fe..84dd3b342d4c0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
@@ -32,7 +31,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[GradientBoostedTrees]].
*/
-class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 92b498580af03..49aff21fe7914 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
-class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext {
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 4ed66953cb628..e6df5d974bf36 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree
import scala.collection.mutable
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -35,7 +34,7 @@ import org.apache.spark.util.Utils
/**
* Test suite for [[RandomForest]].
*/
-class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index b184e936672ca..9d756da410325 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -17,15 +17,14 @@
package org.apache.spark.mllib.tree.impl
-import org.scalatest.FunSuite
-
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suite for [[BaggedPoint]].
*/
-class BaggedPointSuite extends FunSuite with MLlibTestSparkContext {
+class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
test("BaggedPoint RDD: without subsampling") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 668fc1d43c5d6..70219e9ad9d3e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -21,19 +21,19 @@ import java.io.File
import scala.io.Source
-import org.scalatest.FunSuite
-
import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
+import org.apache.spark.SparkException
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
+class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
test("epsilon computation") {
assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
@@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
val fastSquaredDist3 =
fastSquaredDistance(v2, norm2, v3, norm3, precision)
assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m")
- if (m > 10) {
+ if (m > 10) {
val v4 = Vectors.sparse(n, indices.slice(0, m - 10),
indices.map(i => a(i) + 0.5).slice(0, m - 10))
val norm4 = Vectors.norm(v4, 2.0)
@@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
+ val lines =
+ """
+ |0
+ |0 0:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
+ val lines =
+ """
+ |0
+ |0 3:4.0 2:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Utils.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ intercept[SparkException] {
+ loadLibSVMFile(sc, path).collect()
+ }
+ Utils.deleteRecursively(tempDir)
+ }
+
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
@@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
"Each training+validation set combined should contain all of the data.")
}
// K fold cross validation should only have each element in the validation set exactly once
- assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted ===
+ assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted ===
data.collect().sorted)
}
}
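The two new loadLibSVMFile tests pin down the format the loader enforces: LIBSVM feature indices are one-based and must appear in strictly ascending order within a record. The following is only an illustration of that rule in plain Python (check_libsvm_indices is a hypothetical helper, not the loader's actual parser, which surfaces the failure wrapped in a SparkException):

    def check_libsvm_indices(line):
        # A LIBSVM record is "<label> <index>:<value> ...".
        parts = line.split()
        indices = [int(tok.split(":")[0]) for tok in parts[1:]]
        if any(i < 1 for i in indices):
            raise ValueError("indices must be one-based: %s" % line)
        if any(a >= b for a, b in zip(indices, indices[1:])):
            raise ValueError("indices must be in ascending order: %s" % line)
        return indices

    check_libsvm_indices("0 1:4.0 4:5.0 6:6.0")      # well-formed record
    # check_libsvm_indices("0 0:4.0 4:5.0 6:6.0")    # zero-based index -> rejected
    # check_libsvm_indices("0 3:4.0 2:5.0 6:6.0")    # out-of-order indices -> rejected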
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
index f68fb95eac4e4..8dcb9ba9be108 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.util
-import org.scalatest.FunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.SparkException
-
-class NumericParserSuite extends FunSuite {
+class NumericParserSuite extends SparkFunSuite {
test("parser") {
val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)"
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
index 59e6c778806f4..8f475f30249d6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.util
+import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
-import org.scalatest.FunSuite
import org.apache.spark.mllib.util.TestingUtils._
import org.scalatest.exceptions.TestFailedException
-class TestingUtilsSuite extends FunSuite {
+class TestingUtilsSuite extends SparkFunSuite {
test("Comparing doubles using relative error.") {
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 0c3147761cfc5..a85e0a66f4a30 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -22,7 +22,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.4.0-SNAPSHOT</version>
+  <version>1.5.0-SNAPSHOT</version>
   <relativePath>../../pom.xml</relativePath>
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index 7dc7c65825e34..4b5bfcb6f04bc 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -22,7 +22,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.4.0-SNAPSHOT</version>
+  <version>1.5.0-SNAPSHOT</version>
   <relativePath>../../pom.xml</relativePath>
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 1e2e9c80af6cc..a99f7c4392d3d 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -22,7 +22,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.4.0-SNAPSHOT</version>
+  <version>1.5.0-SNAPSHOT</version>
   <relativePath>../../pom.xml</relativePath>
diff --git a/pom.xml b/pom.xml
index c72d7cbf843ef..e9700a5d7b149 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.10</artifactId>
-  <version>1.4.0-SNAPSHOT</version>
+  <version>1.5.0-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>Spark Project Parent POM</name>
   <url>http://spark.apache.org/</url>
@@ -114,11 +114,10 @@
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
-    <akka.group>org.spark-project.akka</akka.group>
-    <akka.version>2.3.4-spark</akka.version>
-    <java.version>1.6</java.version>
+    <akka.group>com.typesafe.akka</akka.group>
+    <akka.version>2.3.11</akka.version>
+    <java.version>1.7</java.version>
     <sbt.project.name>spark</sbt.project.name>
-    <scala.macros.version>2.0.1</scala.macros.version>
     <mesos.version>0.21.1</mesos.version>
     <mesos.classifier>shaded-protobuf</mesos.classifier>
     <slf4j.version>1.7.10</slf4j.version>
@@ -137,7 +136,7 @@
     <hive.version.short>0.13.1</hive.version.short>
     <derby.version>10.10.1.1</derby.version>
-    <parquet.version>1.6.0rc3</parquet.version>
+    <parquet.version>1.7.0</parquet.version>
     <jblas.version>1.2.4</jblas.version>
     <jetty.version>8.1.14.v20131031</jetty.version>
     <orbit.version>3.0.0.v201112011016</orbit.version>
@@ -180,7 +179,7 @@
compile${session.executionRootDirectory}
@@ -269,6 +268,18 @@
         <enabled>false</enabled>
+    <repository>
+      <id>spark-1.4-staging</id>
+      <name>Spark 1.4 RC4 Staging Repository</name>
+      <url>https://repository.apache.org/content/repositories/orgapachespark-1112</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
@@ -576,7 +587,7 @@
         <groupId>io.netty</groupId>
         <artifactId>netty-all</artifactId>
-        <version>4.0.23.Final</version>
+        <version>4.0.28.Final</version>
       </dependency>
       <dependency>
         <groupId>org.apache.derby</groupId>
@@ -1069,13 +1080,13 @@
-        <groupId>com.twitter</groupId>
+        <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-column</artifactId>
         <version>${parquet.version}</version>
         <scope>${parquet.deps.scope}</scope>
       </dependency>
       <dependency>
-        <groupId>com.twitter</groupId>
+        <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-hadoop</artifactId>
         <version>${parquet.version}</version>
         <scope>${parquet.deps.scope}</scope>
@@ -1205,15 +1216,6 @@
               <javacArg>-target</javacArg>
               <javacArg>${java.version}</javacArg>
             </javacArgs>
-            <compilerPlugins>
-              <compilerPlugin>
-                <groupId>org.scalamacros</groupId>
-                <artifactId>paradise_${scala.version}</artifactId>
-                <version>${scala.macros.version}</version>
-              </compilerPlugin>
-            </compilerPlugins>
@@ -1252,7 +1254,9 @@
${test.java.home}
+ testtrue
+ ${project.build.directory}/tmp${spark.test.home}1false
@@ -1284,7 +1288,9 @@
${test.java.home}
+ testtrue
+ ${project.build.directory}/tmp${spark.test.home}1false
@@ -1426,6 +1432,8 @@
2.3false
+
+ false
@@ -1542,6 +1550,26 @@
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-antrun-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>create-tmp-dir</id>
+            <phase>generate-test-resources</phase>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              <target>
+                <mkdir dir="${project.build.directory}/tmp" />
+              </target>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
         <groupId>org.apache.maven.plugins</groupId>
@@ -1664,6 +1692,8 @@
         <hbase.version>0.98.7-hadoop1</hbase.version>
         <avro.mapred.classifier>hadoop1</avro.mapred.classifier>
         <codehaus.jackson.version>1.8.8</codehaus.jackson.version>
+        <akka.group>org.spark-project.akka</akka.group>
+        <akka.version>2.3.4-spark</akka.version>
@@ -1753,22 +1783,6 @@
         <module>sql/hive-thriftserver</module>
       </modules>
     </profile>

-    <profile>
-      <id>hive-0.12.0</id>
-      <properties>
-        <hive.version>0.12.0-protobuf-2.5</hive.version>
-        <hive.version.short>0.12.0</hive.version.short>
-        <derby.version>10.4.2.0</derby.version>
-      </properties>
-    </profile>
-
-    <profile>
-      <id>hive-0.13.1</id>
-      <properties>
-        <hive.version>0.13.1a</hive.version>
-        <hive.version.short>0.13.1</hive.version.short>
-        <derby.version>10.10.1.1</derby.version>
-      </properties>
-    </profile>
-
     <profile>
       <id>scala-2.10</id>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index dde92949fa175..5812b72f0aa78 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -91,7 +91,8 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.3.0"
+ // TODO: Change this once Spark 1.4.0 is released
+ val previousSparkVersion = "1.4.0-rc4"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 11b439e7875fc..8a93ca2999510 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,10 +34,31 @@ import com.typesafe.tools.mima.core.ProblemFilters._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.5") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ // These are needed if checking against the sbt build, since they are part of
+ // the maven-generated artifacts in 1.3.
+ excludePackage("org.spark-project.jetty"),
+ MimaBuild.excludeSparkPackage("unused"),
+ // JavaRDDLike is not meant to be extended by user programs
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.partitioner"),
+ // Mima false positive (was a private[spark] class)
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.util.collection.PairIterator"),
+ // Removing a testing method from a private class
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+ // SQL execution is considered private.
+ excludePackage("org.apache.spark.sql.execution")
+ )
case v if v.startsWith("1.4") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
MimaBuild.excludeSparkPackage("ml"),
+ // SPARK-7910 Adding a method to get the partitioner to JavaRDD,
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"),
// SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
// These are needed if checking against the sbt build, since they are part of
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b9515a12bc573..d7e374558c5e2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConversions._
import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
-import sbtunidoc.Plugin.genjavadocSettings
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
@@ -52,6 +51,11 @@ object BuildCommons {
// Root project.
val spark = ProjectRef(buildLocation, "spark")
val sparkHome = buildLocation
+
+ val testTempDir = s"$sparkHome/target/tmp"
+ if (!new File(testTempDir).isDirectory()) {
+ require(new File(testTempDir).mkdirs())
+ }
}
object SparkBuild extends PomBuild {
@@ -118,7 +122,12 @@ object SparkBuild extends PomBuild {
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq (
+ lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq(
+ libraryDependencies += compilerPlugin(
+ "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full),
+ scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java")))
+
+ lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq (
javaHome := sys.env.get("JAVA_HOME")
.orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() })
.map(file),
@@ -126,7 +135,7 @@ object SparkBuild extends PomBuild {
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
publishMavenStyle := true,
- unidocGenjavadocVersion := "0.8",
+ unidocGenjavadocVersion := "0.9-spark0",
resolvers += Resolver.mavenLocal,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
@@ -174,9 +183,6 @@ object SparkBuild extends PomBuild {
/* Enable unidoc only for the root spark project */
enable(Unidoc.settings)(spark)
- /* Catalyst macro settings */
- enable(Catalyst.settings)(catalyst)
-
/* Spark SQL Core console settings */
enable(SQL.settings)(sql)
@@ -271,14 +277,6 @@ object OldDeps {
)
}
-object Catalyst {
- lazy val settings = Seq(
- addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full),
- // Quasiquotes break compiling scala doc...
- // TODO: Investigate fixing this.
- sources in (Compile, doc) ~= (_ filter (_.getName contains "codegen")))
-}
-
object SQL {
lazy val settings = Seq(
initialCommands in console :=
@@ -503,6 +501,7 @@ object TestSettings {
"SPARK_DIST_CLASSPATH" ->
(fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
"JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))),
+ javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir",
javaOptions in Test += "-Dspark.test.home=" + sparkHome,
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
@@ -511,6 +510,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
+ javaOptions in Test += "-Dderby.system.durability=test",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
javaOptions in Test += "-ea",
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 7096b0d3ee7de..75bd604a1b857 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
-addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1")
+addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3")
addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 0d21a132048a5..adca90ddaf397 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -261,3 +261,7 @@ def _start_update_server():
thread.daemon = True
thread.start()
return server
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index aeb7ad4f2f83e..44d90f1437bc9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -324,10 +324,12 @@ def stop(self):
with SparkContext._lock:
SparkContext._active_spark_context = None
- def range(self, start, end, step=1, numSlices=None):
+ def range(self, start, end=None, step=1, numSlices=None):
"""
Create a new RDD of int containing elements from `start` to `end`
- (exclusive), increased by `step` every element.
+ (exclusive), increased by `step` every element. Can be called the same
+ way as Python's built-in range() function. If called with a single argument,
+ the argument is interpreted as `end`, and `start` is set to 0.
:param start: the start value
:param end: the end value (exclusive)
@@ -335,9 +337,17 @@ def range(self, start, end, step=1, numSlices=None):
:param numSlices: the number of partitions of the new RDD
:return: An RDD of int
+ >>> sc.range(5).collect()
+ [0, 1, 2, 3, 4]
+ >>> sc.range(2, 4).collect()
+ [2, 3]
>>> sc.range(1, 7, 2).collect()
[1, 3, 5]
"""
+ if end is None:
+ end = start
+ start = 0
+
return self.parallelize(xrange(start, end, step), numSlices)
def parallelize(self, c, numSlices=None):
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 23c37167b3711..d8ddb78c6d639 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -205,7 +205,7 @@ def getMetricName(self):
def setParams(self, predictionCol="prediction", labelCol="label",
metricName="rmse"):
"""
- setParams(self, predictionCol="prediction", labelCol="label",
+ setParams(self, predictionCol="prediction", labelCol="label", \
metricName="rmse")
Sets params for regression evaluator.
"""
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index b0479d9b074db..ddb33f427ac64 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -324,65 +324,73 @@ def getP(self):
@inherit_doc
class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
"""
- A one-hot encoder that maps a column of label indices to a column of binary vectors, with
- at most a single one-value. By default, the binary vector has an element for each category, so
- with 5 categories, an input value of 2.0 would map to an output vector of
- (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so
- the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
- of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
- linearly dependent because they sum up to one.
-
- TODO: This method requires the use of StringIndexer first. Decouple them.
+ A one-hot encoder that maps a column of category indices to a
+ column of binary vectors, with at most a single one-value per row
+ that indicates the input category index.
+ For example with 5 categories, an input value of 2.0 would map to
+ an output vector of `[0.0, 0.0, 1.0, 0.0]`.
+ The last category is not included by default (configurable via
+ :py:attr:`dropLast`) because it makes the vector entries sum up to
+ one, and hence linearly dependent.
+ So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ Note that this is different from scikit-learn's OneHotEncoder,
+ which keeps all categories.
+ The output vectors are sparse.
+
+ .. seealso::
+
+ :py:class:`StringIndexer` for converting categorical values into
+ category indices
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> model = stringIndexer.fit(stringIndDf)
>>> td = model.transform(stringIndDf)
- >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features")
+ >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features")
>>> encoder.transform(td).head().features
- SparseVector(2, {})
+ SparseVector(2, {0: 1.0})
>>> encoder.setParams(outputCol="freqs").transform(td).head().freqs
- SparseVector(2, {})
- >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"}
+ SparseVector(2, {0: 1.0})
+ >>> params = {encoder.dropLast: False, encoder.outputCol: "test"}
>>> encoder.transform(td, params).head().test
SparseVector(3, {0: 1.0})
"""
# a placeholder to make it appear in the generated doc
- includeFirst = Param(Params._dummy(), "includeFirst", "include first category")
+ dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category")
@keyword_only
- def __init__(self, includeFirst=True, inputCol=None, outputCol=None):
+ def __init__(self, dropLast=True, inputCol=None, outputCol=None):
"""
- __init__(self, includeFirst=True, inputCol=None, outputCol=None)
+ __init__(self, dropLast=True, inputCol=None, outputCol=None)
"""
super(OneHotEncoder, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
- self.includeFirst = Param(self, "includeFirst", "include first category")
- self._setDefault(includeFirst=True)
+ self.dropLast = Param(self, "dropLast", "whether to drop the last category")
+ self._setDefault(dropLast=True)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
- def setParams(self, includeFirst=True, inputCol=None, outputCol=None):
+ def setParams(self, dropLast=True, inputCol=None, outputCol=None):
"""
- setParams(self, includeFirst=True, inputCol=None, outputCol=None)
+ setParams(self, dropLast=True, inputCol=None, outputCol=None)
Sets params for this OneHotEncoder.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
- def setIncludeFirst(self, value):
+ def setDropLast(self, value):
"""
- Sets the value of :py:attr:`includeFirst`.
+ Sets the value of :py:attr:`dropLast`.
"""
- self._paramMap[self.includeFirst] = value
+ self._paramMap[self.dropLast] = value
return self
- def getIncludeFirst(self):
+ def getDropLast(self):
"""
- Gets the value of includeFirst or its default value.
+ Gets the value of dropLast or its default value.
"""
- return self.getOrDefault(self.includeFirst)
+ return self.getOrDefault(self.dropLast)
@inherit_doc
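To make the dropLast semantics in the new docstring concrete: with 5 categories the encoded vector has 4 slots, an input of 2.0 maps to [0, 0, 1, 0], and the last category maps to the all-zero vector unless dropLast is disabled. A hedged, NumPy-only sketch of that mapping (a dense stand-in for the SparseVector the transformer actually produces; one_hot is a made-up helper, not Spark API):

    import numpy as np

    def one_hot(index, num_categories, drop_last=True):
        # Dense equivalent of the encoder's output for one category index.
        size = num_categories - 1 if drop_last else num_categories
        vec = np.zeros(size)
        if index < size:              # with drop_last the final category stays all-zero
            vec[int(index)] = 1.0
        return vec

    print(one_hot(2.0, 5))                    # [0. 0. 1. 0.]
    print(one_hot(4.0, 5))                    # [0. 0. 0. 0.]
    print(one_hot(4.0, 5, drop_last=False))   # [0. 0. 0. 0. 1.]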
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index b3e0dd7abf681..b06099ac0aee6 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
indicated user preferences rather than explicit ratings given to
items.
+ >>> df = sqlContext.createDataFrame(
+ ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+ ... ["user", "item", "rating"])
>>> als = ALS(rank=10, maxIter=5)
>>> model = als.fit(df)
+ >>> model.rank
+ 10
+ >>> model.userFactors.orderBy("id").collect()
+ [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
Model fitted by ALS.
"""
+ @property
+ def rank(self):
+ """rank of the matrix factorization model"""
+ return self._call_java("rank")
+
+ @property
+ def userFactors(self):
+ """
+ a DataFrame that stores user factors in two columns: `id` and
+ `features`
+ """
+ return self._call_java("userFactors")
+
+ @property
+ def itemFactors(self):
+ """
+ a DataFrame that stores item factors in two columns: `id` and
+ `features`
+ """
+ return self._call_java("itemFactors")
+
if __name__ == "__main__":
import doctest
@@ -272,8 +300,6 @@ class ALSModel(JavaModel):
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
- (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
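The new rank/userFactors/itemFactors properties expose the factorization directly; an ALS prediction for a (user, item) pair is the dot product of the corresponding factor vectors. A toy illustration in plain Python (the rank-3 factor values below are invented for the example, not output of the doctest's rank-10 model):

    user_factors = {0: [0.8, -0.1, 0.3], 1: [0.2, 0.5, -0.4]}   # hypothetical user factors
    item_factors = {0: [1.0, 0.0, 0.2], 1: [0.3, 0.9, 0.1]}     # hypothetical item factors

    def predict(user, item):
        # Predicted rating = dot product of the user and item factor vectors.
        return sum(u * i for u, i in zip(user_factors[user], item_factors[item]))

    print(predict(0, 1))   # approximately 0.18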
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 497841b6c8ce6..0bf988fd72f14 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
- ... [(Vectors.dense([0.0, 1.0]), 0.0),
- ... (Vectors.dense([1.0, 2.0]), 1.0),
- ... (Vectors.dense([0.55, 3.0]), 0.0),
- ... (Vectors.dense([0.45, 4.0]), 1.0),
- ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... [(Vectors.dense([0.0]), 0.0),
+ ... (Vectors.dense([0.4]), 1.0),
+ ... (Vectors.dense([0.5]), 0.0),
+ ... (Vectors.dense([0.6]), 1.0),
+ ... (Vectors.dense([1.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- >>> # SPARK-7432: The following test is flaky.
- >>> # cvModel = cv.fit(dataset)
- >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
- >>> # cvModel.transform(dataset).collect() == expected.collect()
+ >>> cvModel = cv.fit(dataset)
+ >>> evaluator.evaluate(cvModel.transform(dataset))
+ 0.8333...
"""
# a placeholder to make it appear in the generated doc
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 07507b2ad0d05..acba3a717d21a 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -23,16 +23,10 @@
# MLlib currently needs NumPy 1.4+, so complain if lower
import numpy
-if numpy.version.version < '1.4':
+
+ver = [int(x) for x in numpy.version.version.split('.')[:2]]
+if ver < [1, 4]:
raise Exception("MLlib requires NumPy 1.4+")
__all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random',
'recommendation', 'regression', 'stat', 'tree', 'util']
-
-import sys
-from . import rand as random
-modname = __name__ + '.random'
-random.__name__ = modname
-random.RandomRDDs.__module__ = modname
-sys.modules[modname] = random
-del modname, sys
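The switch from comparing version strings to comparing parsed integer components matters once NumPy minor versions reach two digits; lexicographic comparison would reject perfectly recent releases. A minimal illustration of the pitfall the new check avoids:

    # String comparison mis-orders multi-digit components:
    print('1.10.4' < '1.4')                            # True -- NumPy 1.10+ would wrongly fail the check
    # Comparing the parsed [major, minor] list gives the intended ordering:
    print([int(x) for x in '1.10.4'.split('.')[:2]])   # [1, 10]
    print([1, 10] < [1, 4])                            # False -- 1.10 correctly passes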
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index ba6058978880a..855e85f57155e 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -27,7 +27,7 @@
from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-
+from pyspark.sql import DataFrame, SQLContext
# Hack for support float('inf') in Py4j
_old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
jrdd = sc._jvm.SerDe.javaToPython(r)
return RDD(jrdd, sc)
+ if clsName == 'DataFrame':
+ return DataFrame(r, SQLContext(sc))
+
if clsName in _picklable_classes:
r = sc._jvm.SerDe.dumps(r)
elif isinstance(r, (JavaArray, JavaList)):
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index aab5e5f4b77b5..c5cf3a4e7ff22 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
Evaluator for binary classification.
+ :param scoreAndLabels: an RDD of (score, label) pairs
+
>>> scoreAndLabels = sc.parallelize([
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
>>> metrics = BinaryClassificationMetrics(scoreAndLabels)
@@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
def __init__(self, scoreAndLabels):
- """
- :param scoreAndLabels: an RDD of (score, label) pairs
- """
sc = scoreAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
@@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper):
"""
Evaluator for regression.
+ :param predictionAndObservations: an RDD of (prediction,
+ observation) pairs.
+
>>> predictionAndObservations = sc.parallelize([
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
>>> metrics = RegressionMetrics(predictionAndObservations)
@@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndObservations):
- """
- :param predictionAndObservations: an RDD of (prediction, observation) pairs.
- """
sc = predictionAndObservations.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
@@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
+ :param predictionAndLabels: an RDD of (prediction, label) pairs.
+
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
@@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels an RDD of (prediction, label) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
@@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper):
"""
Evaluator for ranking algorithms.
+ :param predictionAndLabels: an RDD of (predicted ranking,
+ ground truth set) pairs.
+
>>> predictionAndLabels = sc.parallelize([
... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
@@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels,
@@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper):
"""
Evaluator for multilabel classification.
+ :param predictionAndLabels: an RDD of (predictions, labels) pairs,
+ both are non-null Arrays, each with
+ unique elements.
+
>>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index aac305db6c19a..da90554f41437 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -68,6 +68,8 @@ class Normalizer(VectorTransformer):
For `p` = float('inf'), max(abs(vector)) will be used as norm for
normalization.
+ :param p: Normalization in L^p^ space, p = 2 by default.
+
>>> v = Vectors.dense(range(3))
>>> nor = Normalizer(1)
>>> nor.transform(v)
@@ -82,9 +84,6 @@ class Normalizer(VectorTransformer):
DenseVector([0.0, 0.5, 1.0])
"""
def __init__(self, p=2.0):
- """
- :param p: Normalization in L^p^ space, p = 2 by default.
- """
assert p >= 1.0, "p should be greater than 1.0"
self.p = float(p)
@@ -94,7 +93,7 @@ def transform(self, vector):
:param vector: vector or RDD of vector to be normalized.
:return: normalized vector. If the norm of the input is zero, it
- will return the input vector.
+ will return the input vector.
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
@@ -164,6 +163,13 @@ class StandardScaler(object):
variance using column summary statistics on the samples in the
training set.
+ :param withMean: False by default. Centers the data with mean
+ before scaling. It will build a dense output, so this
+ does not work on sparse input and will raise an
+ exception.
+ :param withStd: True by default. Scales the data to unit
+ standard deviation.
+
>>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
>>> dataset = sc.parallelize(vs)
>>> standardizer = StandardScaler(True, True)
@@ -174,14 +180,6 @@ class StandardScaler(object):
DenseVector([0.7071, -0.7071, 0.7071])
"""
def __init__(self, withMean=False, withStd=True):
- """
- :param withMean: False by default. Centers the data with mean
- before scaling. It will build a dense output, so this
- does not work on sparse input and will raise an
- exception.
- :param withStd: True by default. Scales the data to unit
- standard deviation.
- """
if not (withMean or withStd):
warnings.warn("Both withMean and withStd are false. The model does nothing.")
self.withMean = withMean
@@ -193,7 +191,7 @@ def fit(self, dataset):
for later scaling.
:param data: The data used to compute the mean and variance
- to build the transformation model.
+ to build the transformation model.
:return: a StandardScalarModel
"""
dataset = dataset.map(_convert_to_vector)
@@ -223,6 +221,8 @@ class ChiSqSelector(object):
Creates a ChiSquared feature selector.
+ :param numTopFeatures: number of features that selector will select.
+
>>> data = [
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
@@ -236,9 +236,6 @@ class ChiSqSelector(object):
DenseVector([5.0])
"""
def __init__(self, numTopFeatures):
- """
- :param numTopFeatures: number of features that selector will select.
- """
self.numTopFeatures = int(numTopFeatures)
def fit(self, data):
@@ -246,9 +243,9 @@ def fit(self, data):
Returns a ChiSquared feature selector.
:param data: an `RDD[LabeledPoint]` containing the labeled dataset
- with categorical features. Real-valued features will be
- treated as categorical for each distinct value.
- Apply feature discretizer before using this function.
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ Apply feature discretizer before using this function.
"""
jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
return ChiSqSelectorModel(jmodel)
@@ -263,15 +260,14 @@ class HashingTF(object):
Note: the terms must be hashable (can not be dict/set/list...).
+ :param numFeatures: number of features (default: 2^20)
+
>>> htf = HashingTF(100)
>>> doc = "a a b b c d".split(" ")
>>> htf.transform(doc)
SparseVector(100, {...})
"""
def __init__(self, numFeatures=1 << 20):
- """
- :param numFeatures: number of features (default: 2^20)
- """
self.numFeatures = numFeatures
def indexOf(self, term):
@@ -311,7 +307,7 @@ def transform(self, x):
Call transform directly on the RDD instead.
:param x: an RDD of term frequency vectors or a term frequency
- vector
+ vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
if isinstance(x, RDD):
@@ -342,6 +338,9 @@ class IDF(object):
`minDocFreq`). For terms that are not in at least `minDocFreq`
documents, the IDF is found as 0, resulting in TF-IDFs of 0.
+ :param minDocFreq: minimum number of documents in which a term
+ should appear for filtering
+
>>> n = 4
>>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
... Vectors.dense([0.0, 1.0, 2.0, 3.0]),
@@ -362,10 +361,6 @@ class IDF(object):
SparseVector(4, {1: 0.0, 3: 0.5754})
"""
def __init__(self, minDocFreq=0):
- """
- :param minDocFreq: minimum of documents in which a term
- should appear for filtering
- """
self.minDocFreq = minDocFreq
def fit(self, dataset):
diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py
similarity index 100%
rename from python/pyspark/mllib/rand.py
rename to python/pyspark/mllib/random.py
diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py
new file mode 100644
index 0000000000000..7da921976d4d2
--- /dev/null
+++ b/python/pyspark/mllib/stat/KernelDensity.py
@@ -0,0 +1,61 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+if sys.version > '3':
+ xrange = range
+
+import numpy as np
+
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.rdd import RDD
+
+
+class KernelDensity(object):
+ """
+ .. note:: Experimental
+
+ Estimate probability density at required points given an RDD of samples
+ from the population.
+
+ >>> kd = KernelDensity()
+ >>> sample = sc.parallelize([0.0, 1.0])
+ >>> kd.setSample(sample)
+ >>> kd.estimate([0.0, 1.0])
+ array([ 0.12938758, 0.12938758])
+ """
+ def __init__(self):
+ self._bandwidth = 1.0
+ self._sample = None
+
+ def setBandwidth(self, bandwidth):
+ """Set bandwidth of each sample. Defaults to 1.0"""
+ self._bandwidth = bandwidth
+
+ def setSample(self, sample):
+ """Set sample points from the population. Should be a RDD"""
+ if not isinstance(sample, RDD):
+ raise TypeError("samples should be a RDD, received %s" % type(sample))
+ self._sample = sample
+
+ def estimate(self, points):
+ """Estimate the probability density at points"""
+ points = list(points)
+ densities = callMLlibFunc(
+ "estimateKernelDensity", self._sample, self._bandwidth, points)
+ return np.asarray(densities)
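A short usage sketch mirroring the Scala KernelDensitySuite change earlier in this patch (the local master and app name are arbitrary; the printed values are approximate):

    from pyspark import SparkContext
    from pyspark.mllib.stat import KernelDensity

    sc = SparkContext("local[2]", "kde-example")
    kd = KernelDensity()
    kd.setSample(sc.parallelize([5.0]))   # single sample, as in the Scala test
    kd.setBandwidth(3.0)                  # non-default bandwidth
    print(kd.estimate([5.0, 6.0]))        # roughly [0.1330, 0.1258], the N(5.0, 3.0) densities
    sc.stop()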
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
index e3e128513e0d7..c8a721d3fe41c 100644
--- a/python/pyspark/mllib/stat/__init__.py
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -22,6 +22,7 @@
from pyspark.mllib.stat._statistics import *
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.stat.test import ChiSqTestResult
+from pyspark.mllib.stat.KernelDensity import KernelDensity
__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult",
- "MultivariateGaussian"]
+ "MultivariateGaussian", "KernelDensity"]
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 1d0b16cade8bb..81c420ce16541 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -362,7 +362,7 @@ def _spill(self):
self.spills += 1
gc.collect() # release the memory as much as possible
- MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
def items(self):
""" Return all merged items as iterator """
@@ -515,7 +515,7 @@ def load(f):
gc.collect()
batch //= 2
limit = self._next_limit()
- MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
DiskBytesSpilled += os.path.getsize(path)
os.unlink(path) # data will be deleted after close
@@ -630,7 +630,7 @@ def _spill(self):
self.values = []
gc.collect()
DiskBytesSpilled += self._file.tell() - pos
- MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
class ExternalListOfList(ExternalList):
@@ -794,7 +794,7 @@ def _spill(self):
self.spills += 1
gc.collect() # release the memory as much as possible
- MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
def _merged_items(self, index):
size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
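The clamp protects the MemoryBytesSpilled counter when get_used_memory() reports more memory in use after a spill than before it (for example, when other threads allocate while the spill runs or the allocator does not return pages to the OS); without it, a negative delta shifted left by 20 would drive the metric backwards. A tiny illustration with hypothetical readings:

    # Hypothetical memory readings (in MB) taken before and after a spill.
    used_before, used_after = 512, 530

    print((used_before - used_after) << 20)          # -18874368: would decrease the counter
    print(max(used_before - used_after, 0) << 20)    # 0: the clamped update records nothing instead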
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 8fee92ae3aed5..ad9c891ba1c04 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -45,22 +45,19 @@
def since(version):
+ """
+ A decorator that annotates a function to append the version of Spark the function was added.
+ """
+ import re
+ indent_p = re.compile(r'\n( +)')
+
def deco(f):
- f.__doc__ = f.__doc__.rstrip() + "\n\n.. versionadded:: %s" % version
+ indents = indent_p.findall(f.__doc__)
+ indent = ' ' * (min(len(m) for m in indents) if indents else 0)
+ f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
return f
return deco
-# fix the module name conflict for Python 3+
-import sys
-from . import _types as types
-modname = __name__ + '.types'
-types.__name__ = modname
-# update the __module__ for all objects, make them picklable
-for v in types.__dict__.values():
- if hasattr(v, "__module__") and v.__module__.endswith('._types'):
- v.__module__ = modname
-sys.modules[modname] = types
-del modname, sys
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
@@ -70,7 +67,9 @@ def deco(f):
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
+
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
+ 'DataFrameReader', 'DataFrameWriter'
]
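The reworked since() matches the docstring's smallest existing indentation so the appended Sphinx directive stays nested inside the description instead of landing at column zero. A standalone sketch reproducing the patched decorator's behaviour outside pyspark (the decorated example function is made up):

    import re

    def since(version):
        # Same logic as the patched pyspark.sql.since.
        indent_p = re.compile(r'\n( +)')

        def deco(f):
            indents = indent_p.findall(f.__doc__)
            indent = ' ' * (min(len(m) for m in indents) if indents else 0)
            f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
            return f
        return deco

    @since("1.4")
    def example():
        """Docstring body.

        :return: nothing
        """

    print(example.__doc__)   # the appended ".. versionadded:: 1.4" line is indented four spaces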
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 8dc5039f587f0..1ecec5b126505 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,14 @@ def between(self, lowerBound, upperBound):
"""
A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
+
+ >>> df.select(df.name, df.age.between(2, 4)).show()
+ +-----+--------------------------+
+ | name|((age >= 2) && (age <= 4))|
+ +-----+--------------------------+
+ |Alice| true|
+ | Bob| false|
+ +-----+--------------------------+
"""
return (self >= lowerBound) & (self <= upperBound)
@@ -328,12 +336,20 @@ def when(self, condition, value):
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+ +-----+--------------------------------------------------------+
+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
+ +-----+--------------------------------------------------------+
+ |Alice| -1|
+ | Bob| 1|
+ +-----+--------------------------------------------------------+
"""
- sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
- jc = sc._jvm.functions.when(condition._jc, v)
+ jc = self._jc.when(condition._jc, v)
return Column(jc)
@since(1.4)
@@ -345,9 +361,18 @@ def otherwise(self, value):
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
+
+ >>> from pyspark.sql import functions as F
+ >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+ +-----+---------------------------------+
+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
+ +-----+---------------------------------+
+ |Alice| 0|
+ | Bob| 1|
+ +-----+---------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
- jc = self._jc.otherwise(value)
+ jc = self._jc.otherwise(v)
return Column(jc)
@since(1.4)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 22f6257dfe02d..599c9ac5794a2 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -124,11 +124,14 @@ def getConf(self, key, defaultValue):
@property
@since("1.3.1")
def udf(self):
- """Returns a :class:`UDFRegistration` for UDF registration."""
+ """Returns a :class:`UDFRegistration` for UDF registration.
+
+ :return: :class:`UDFRegistration`
+ """
return UDFRegistration(self)
@since(1.4)
- def range(self, start, end, step=1, numPartitions=None):
+ def range(self, start, end=None, step=1, numPartitions=None):
"""
Create a :class:`DataFrame` with single LongType column named `id`,
containing elements in a range from `start` to `end` (exclusive) with
@@ -138,14 +141,24 @@ def range(self, start, end, step=1, numPartitions=None):
:param end: the end value (exclusive)
:param step: the incremental step (default: 1)
:param numPartitions: the number of partitions of the DataFrame
- :return: A new DataFrame
+ :return: :class:`DataFrame`
>>> sqlContext.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
+
+ If only one argument is specified, it will be used as the end value.
+
+ >>> sqlContext.range(3).collect()
+ [Row(id=0), Row(id=1), Row(id=2)]
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
- jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
+ if end is None:
+ jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
+ else:
+ jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
return DataFrame(jdf, self)
@ignore_unicode_prefix
@@ -195,8 +208,8 @@ def _inferSchema(self, rdd, samplingRatio=None):
raise ValueError("The first row in RDD is empty, "
"can not infer schema")
if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.sql.Row instead")
+ warnings.warn("Using RDD of dict to inferSchema is deprecated. "
+ "Use pyspark.sql.Row instead")
if samplingRatio is None:
schema = _infer_schema(first)
@@ -219,7 +232,7 @@ def inferSchema(self, rdd, samplingRatio=None):
"""
.. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
"""
- warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
+ warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
@@ -262,6 +275,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
:class:`list`, or :class:`pandas.DataFrame`.
:param schema: a :class:`StructType` or list of column names. default None.
:param samplingRatio: the sample ratio of rows used for inferring
+ :return: :class:`DataFrame`
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
@@ -359,18 +373,15 @@ def registerDataFrameAsTable(self, df, tableName):
else:
raise ValueError("Can only register DataFrame as table")
- @since(1.0)
def parquetFile(self, *paths):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.parquetFile(parquetFile)
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.
+
+ >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
+ warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
gateway = self._sc._gateway
jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
for i in range(0, len(paths)):
@@ -378,39 +389,15 @@ def parquetFile(self, *paths):
jdf = self._ssql_ctx.parquetFile(jpaths)
return DataFrame(jdf, self)
- @since(1.0)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""Loads a text file storing one JSON object per line as a :class:`DataFrame`.
- If the schema is provided, applies the given schema to this JSON dataset.
- Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.
- >>> import tempfile, shutil
- >>> jsonFile = tempfile.mkdtemp()
- >>> shutil.rmtree(jsonFile)
- >>> with open(jsonFile, 'w') as f:
- ... f.writelines(jsonStrings)
- >>> df1 = sqlContext.jsonFile(jsonFile)
- >>> df1.printSchema()
- root
- |-- field1: long (nullable = true)
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field4: long (nullable = true)
-
- >>> from pyspark.sql.types import *
- >>> schema = StructType([
- ... StructField("field2", StringType()),
- ... StructField("field3",
- ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlContext.jsonFile(jsonFile, schema)
- >>> df2.printSchema()
- root
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field5: array (nullable = true)
- | | |-- element: integer (containsNull = true)
+ >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
+ [('age', 'bigint'), ('name', 'string')]
"""
+ warnings.warn("jsonFile is deprecated. Use read.json() instead.")
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
@@ -462,21 +449,16 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
- @since(1.3)
def load(self, path=None, source=None, schema=None, **options):
"""Returns the dataset in a data source as a :class:`DataFrame`.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Optionally, a schema can be provided as the schema of the returned DataFrame.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
"""
+ warnings.warn("load is deprecated. Use read.load() instead.")
return self.read.load(path, source, schema, **options)
@since(1.3)
- def createExternalTable(self, tableName, path=None, source=None,
- schema=None, **options):
+ def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
"""Creates an external table based on the dataset in a data source.
It returns the DataFrame associated with the external table.
@@ -487,6 +469,8 @@ def createExternalTable(self, tableName, path=None, source=None,
Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
created external table.
+
+ :return: :class:`DataFrame`
"""
if path is not None:
options["path"] = path
@@ -508,6 +492,8 @@ def createExternalTable(self, tableName, path=None, source=None,
def sql(self, sqlQuery):
"""Returns a :class:`DataFrame` representing the result of the given query.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -519,6 +505,8 @@ def sql(self, sqlQuery):
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -536,6 +524,9 @@ def tables(self, dbName=None):
The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
(a column with :class:`BooleanType` indicating if a table is a temporary one or not).
+ :param dbName: string, name of the database to use.
+ :return: :class:`DataFrame`
+
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
@@ -550,7 +541,8 @@ def tables(self, dbName=None):
def tableNames(self, dbName=None):
"""Returns a list of names of tables in the database ``dbName``.
- If ``dbName`` is not specified, the current database will be used.
+ :param dbName: string, name of the database to use. Defaults to the current database.
+ :return: a list of table names, as strings
>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> "table1" in sqlContext.tableNames()
@@ -585,8 +577,7 @@ def read(self):
Returns a :class:`DataFrameReader` that can be used to read data
in as a :class:`DataFrame`.
- >>> sqlContext.read
-
+ :return: :class:`DataFrameReader`
"""
return DataFrameReader(self)
@@ -644,10 +635,14 @@ def register(self, name, f, returnType=StringType()):
def _test():
+ import os
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.context
+
+ os.chdir(os.environ["SPARK_HOME"])
+
globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 936487519a645..9615e576497cd 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,6 +22,7 @@
if sys.version >= '3':
basestring = unicode = str
long = int
+ from functools import reduce
else:
from itertools import imap as map
@@ -44,7 +45,7 @@ class DataFrame(object):
A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
and can be created using various functions in :class:`SQLContext`::
- people = sqlContext.parquetFile("...")
+ people = sqlContext.read.parquet("...")
Once created, it can be manipulated using the various domain-specific-language
(DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
@@ -56,8 +57,8 @@ class DataFrame(object):
A more concrete example::
# To create DataFrame using SQLContext
- people = sqlContext.parquetFile("...")
- department = sqlContext.parquetFile("...")
+ people = sqlContext.read.parquet("...")
+ department = sqlContext.read.parquet("...")
people.filter(people.age > 30).join(department, people.deptId == department.id)) \
.groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
@@ -120,21 +121,12 @@ def toJSON(self, use_unicode=True):
rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
- @since(1.3)
def saveAsParquetFile(self, path):
"""Saves the contents as a Parquet file, preserving the schema.
- Files that are written out using this method can be read back in as
- a :class:`DataFrame` using :func:`SQLContext.parquetFile`.
-
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.parquetFile(parquetFile)
- >>> sorted(df2.collect()) == sorted(df.collect())
- True
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead.
"""
+ warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.")
self._jdf.saveAsParquetFile(path)
@since(1.3)
@@ -151,69 +143,45 @@ def registerTempTable(self, name):
"""
self._jdf.registerTempTable(name)
- @since(1.3)
def registerAsTable(self, name):
- """DEPRECATED: use :func:`registerTempTable` instead"""
- warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning)
+ """
+ .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead.
+ """
+ warnings.warn("Use registerTempTable instead of registerAsTable.")
self.registerTempTable(name)
- @since(1.3)
def insertInto(self, tableName, overwrite=False):
"""Inserts the contents of this :class:`DataFrame` into the specified table.
- Optionally overwriting any existing data.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.insertInto() instead.")
self.write.insertInto(tableName, overwrite)
- @since(1.3)
def saveAsTable(self, tableName, source=None, mode="error", **options):
"""Saves the contents of this :class:`DataFrame` to a data source as a table.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the saveAsTable operation when
- table already exists in the data source. There are four modes:
-
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.")
self.write.saveAsTable(tableName, source, mode, **options)
@since(1.3)
def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
- The data source is specified by the ``source`` and a set of ``options``.
- If ``source`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
-
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead.
"""
+ warnings.warn("insertInto is deprecated. Use write.save() instead.")
return self.write.save(path, source, mode, **options)
@property
@since(1.4)
def write(self):
"""
- Interface for saving the content of the :class:`DataFrame` out
- into external storage.
-
- :return :class:`DataFrameWriter`
+ Interface for saving the content of the :class:`DataFrame` out into external storage.
- .. note:: Experimental
-
- >>> df.write
-
+ :return: :class:`DataFrameWriter`
"""
return DataFrameWriter(self)
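
The writer side mirrors this: insertInto/saveAsTable/save on DataFrame now only warn and forward to df.write. A minimal sketch of the replacement calls, assuming a local Spark 1.4 build and a hypothetical app name:

    import os
    import tempfile
    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "write-migration")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])

    out = os.path.join(tempfile.mkdtemp(), 'data')

    # Deprecated form: df.save(out, source="parquet", mode="error")
    # New form through the DataFrameWriter returned by df.write:
    df.write.save(out, format="parquet", mode="error")

    sc.stop()
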
@@ -536,36 +504,52 @@ def alias(self, alias):
@ignore_unicode_prefix
@since(1.3)
- def join(self, other, joinExprs=None, joinType=None):
+ def join(self, other, on=None, how=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
The following performs a full outer join between ``df1`` and ``df2``.
:param other: Right side of the join
- :param joinExprs: a string for join column name, or a join expression (Column).
- If joinExprs is a string indicating the name of the join column,
- the column must exist on both sides, and this performs an inner equi-join.
- :param joinType: str, default 'inner'.
+ :param on: a string for join column name, a list of column names,
+ a join expression (Column) or a list of Columns.
+ If `on` is a string or a list of strings indicating the name of the join column(s),
+ the column(s) must exist on both sides, and this performs an inner equi-join.
+ :param how: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+ >>> cond = [df.name == df3.name, df.age == df3.age]
+ >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
+ [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+
>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name=u'Bob', height=85)]
+
+ >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect()
+ [Row(name=u'Bob', age=5)]
"""
- if joinExprs is None:
+ if on is not None and not isinstance(on, list):
+ on = [on]
+
+ if on is None or len(on) == 0:
jdf = self._jdf.join(other._jdf)
- elif isinstance(joinExprs, basestring):
- jdf = self._jdf.join(other._jdf, joinExprs)
+
+ elif isinstance(on[0], basestring):
+ jdf = self._jdf.join(other._jdf, self._jseq(on))
else:
- assert isinstance(joinExprs, Column), "joinExprs should be Column"
- if joinType is None:
- jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ assert isinstance(on[0], Column), "on should be Column or list of Column"
+ if len(on) > 1:
+ on = reduce(lambda x, y: x.__and__(y), on)
+ else:
+ on = on[0]
+ if how is None:
+ jdf = self._jdf.join(other._jdf, on._jc, "inner")
else:
- assert isinstance(joinType, basestring), "joinType should be basestring"
- jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+ assert isinstance(how, basestring), "how should be basestring"
+ jdf = self._jdf.join(other._jdf, on._jc, how)
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
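
A minimal sketch of the four `on` forms the rewritten join() accepts (a string, a list of strings, a Column, or a list of Columns that are AND-ed together by the reduce step above); it assumes a local Spark 1.4 build and a hypothetical app name:

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "join-forms")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])
    df3 = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])

    df.join(df3, 'name').show()                          # single column name: inner equi-join
    df.join(df3, ['name', 'age']).show()                 # list of column names
    df.join(df3, df.name == df3.name, 'outer').show()    # Column expression plus join type
    # List of Column expressions; combined with AND, matching reduce(__and__) above.
    df.join(df3, [df.name == df3.name, df.age == df3.age], 'outer').show()

    sc.stop()
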
@@ -636,6 +620,9 @@ def describe(self, *cols):
This include count, mean, stddev, min, and max. If no columns are
given, this function computes statistics for all numerical columns.
+ .. note:: This function is meant for exploratory data analysis, as we make no \
+ guarantee about the backward compatibility of the schema of the resulting DataFrame.
+
>>> df.describe().show()
+-------+---+
|summary|age|
@@ -646,16 +633,30 @@ def describe(self, *cols):
| min| 2|
| max| 5|
+-------+---+
+ >>> df.describe(['age', 'name']).show()
+ +-------+---+-----+
+ |summary|age| name|
+ +-------+---+-----+
+ | count| 2| 2|
+ | mean|3.5| null|
+ | stddev|1.5| null|
+ | min| 2|Alice|
+ | max| 5| Bob|
+ +-------+---+-----+
"""
+ if len(cols) == 1 and isinstance(cols[0], list):
+ cols = cols[0]
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@since(1.3)
def head(self, n=None):
- """
- Returns the first ``n`` rows as a list of :class:`Row`,
- or the first :class:`Row` if ``n`` is ``None.``
+ """Returns the first ``n`` rows.
+
+ :param n: int, default 1. Number of rows to return.
+ :return: If n is greater than 1, return a list of :class:`Row`.
+ If n is 1, return a single Row.
>>> df.head()
Row(age=2, name=u'Alice')
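
Because of the list-unwrapping added above, describe() now accepts either varargs or a single list; a minimal sketch, assuming a local Spark 1.4 build and a hypothetical app name:

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "describe-head")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])

    df.describe('age', 'name').show()     # varargs...
    df.describe(['age', 'name']).show()   # ...and a single list are now equivalent

    print(df.head())     # a single Row when n is omitted
    print(df.head(2))    # a list of Rows when n is given

    sc.stop()
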
@@ -745,7 +746,7 @@ def selectExpr(self, *expr):
This is a variant of :func:`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
+ [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
"""
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
@@ -925,8 +926,7 @@ def dropDuplicates(self, subset=None):
@since("1.3.1")
def dropna(self, how='any', thresh=None, subset=None):
"""Returns a new :class:`DataFrame` omitting rows with null values.
-
- This is an alias for ``na.drop()``.
+ :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other.
:param how: 'any' or 'all'.
If 'any', drop a row if it contains any nulls.
@@ -936,13 +936,6 @@ def dropna(self, how='any', thresh=None, subset=None):
This overwrites the `how` parameter.
:param subset: optional list of column names to consider.
- >>> df4.dropna().show()
- +---+------+-----+
- |age|height| name|
- +---+------+-----+
- | 10| 80|Alice|
- +---+------+-----+
-
>>> df4.na.drop().show()
+---+------+-----+
|age|height| name|
@@ -968,6 +961,7 @@ def dropna(self, how='any', thresh=None, subset=None):
@since("1.3.1")
def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
+ :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other.
:param value: int, long, float, string, or dict.
Value to replace null values with.
@@ -979,7 +973,7 @@ def fillna(self, value, subset=None):
For example, if `value` is a string, and subset contains a non-string column,
then the non-string column is simply ignored.
- >>> df4.fillna(50).show()
+ >>> df4.na.fill(50).show()
+---+------+-----+
|age|height| name|
+---+------+-----+
@@ -989,16 +983,6 @@ def fillna(self, value, subset=None):
| 50| 50| null|
+---+------+-----+
- >>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
- +---+------+-------+
- |age|height| name|
- +---+------+-------+
- | 10| 80| Alice|
- | 5| null| Bob|
- | 50| null| Tom|
- | 50| null|unknown|
- +---+------+-------+
-
>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
+---+------+-------+
|age|height| name|
@@ -1030,6 +1014,8 @@ def fillna(self, value, subset=None):
@since(1.4)
def replace(self, to_replace, value, subset=None):
"""Returns a new :class:`DataFrame` replacing a value with another value.
+ :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
+ aliases of each other.
:param to_replace: int, long, float, string, or list.
Value to be replaced.
@@ -1045,7 +1031,7 @@ def replace(self, to_replace, value, subset=None):
For example, if `value` is a string, and subset contains a non-string column,
then the non-string column is simply ignored.
- >>> df4.replace(10, 20).show()
+ >>> df4.na.replace(10, 20).show()
+----+------+-----+
| age|height| name|
+----+------+-----+
@@ -1055,7 +1041,7 @@ def replace(self, to_replace, value, subset=None):
|null| null| null|
+----+------+-----+
- >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+ >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
+----+------+----+
@@ -1106,9 +1092,9 @@ def replace(self, to_replace, value, subset=None):
@since(1.4)
def corr(self, col1, col2, method=None):
"""
- Calculates the correlation of two columns of a DataFrame as a double value. Currently only
- supports the Pearson Correlation Coefficient.
- :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
+ Calculates the correlation of two columns of a DataFrame as a double value.
+ Currently only supports the Pearson Correlation Coefficient.
+ :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.
:param col1: The name of the first column
:param col2: The name of the second column
@@ -1170,6 +1156,9 @@ def freqItems(self, cols, support=None):
"http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
:func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
+ .. note:: This function is meant for exploratory data analysis, as we make no \
+ guarantee about the backward compatibility of the schema of the resulting DataFrame.
+
:param cols: Names of the columns to calculate frequent items for as a list or tuple of
strings.
:param support: The frequency with which to consider an item 'frequent'. Default is 1%.
@@ -1214,15 +1203,30 @@ def withColumnRenamed(self, existing, new):
@since(1.4)
@ignore_unicode_prefix
- def drop(self, colName):
+ def drop(self, col):
"""Returns a new :class:`DataFrame` that drops the specified column.
- :param colName: string, name of the column to drop.
+ :param col: a string name of the column to drop, or a
+ :class:`Column` to drop.
>>> df.drop('age').collect()
[Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.drop(df.age).collect()
+ [Row(name=u'Alice'), Row(name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df.name).collect()
+ [Row(age=5, height=85, name=u'Bob')]
+
+ >>> df.join(df2, df.name == df2.name, 'inner').drop(df2.name).collect()
+ [Row(age=5, name=u'Bob', height=85)]
"""
- jdf = self._jdf.drop(colName)
+ if isinstance(col, basestring):
+ jdf = self._jdf.drop(col)
+ elif isinstance(col, Column):
+ jdf = self._jdf.drop(col._jc)
+ else:
+ raise TypeError("col should be a string or a Column")
return DataFrame(jdf, self.sql_ctx)
@since(1.3)
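
A minimal sketch of dropping by Column rather than by name, which is what makes the post-join cases in the doctests above unambiguous (local Spark 1.4 build, hypothetical app name):

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "drop-column")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])
    df2 = sqlContext.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)])

    df.drop('age').show()     # by name, as before
    df.drop(df.age).show()    # by Column reference

    # After a join both sides carry a 'name' column; a Column reference picks one side.
    df.join(df2, df.name == df2.name, 'inner').drop(df2.name).show()

    sc.stop()
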
@@ -1239,7 +1243,10 @@ def toPandas(self):
import pandas as pd
return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+ ##########################################################################################
# Pandas compatibility
+ ##########################################################################################
+
groupby = groupBy
drop_duplicates = dropDuplicates
@@ -1259,6 +1266,8 @@ def _to_scala_map(sc, jm):
class DataFrameNaFunctions(object):
"""Functionality for working with missing data in :class:`DataFrame`.
+
+ .. versionadded:: 1.4
"""
def __init__(self, df):
@@ -1274,9 +1283,16 @@ def fill(self, value, subset=None):
fill.__doc__ = DataFrame.fillna.__doc__
+ def replace(self, to_replace, value, subset=None):
+ return self.df.replace(to_replace, value, subset)
+
+ replace.__doc__ = DataFrame.replace.__doc__
+
class DataFrameStatFunctions(object):
"""Functionality for statistic functions with :class:`DataFrame`.
+
+ .. versionadded:: 1.4
"""
def __init__(self, df):
@@ -1316,6 +1332,8 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2),
+ Row(name='Bob', age=5)]).toDF()
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b6fd413bec7db..f036644acc961 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -43,6 +43,44 @@ def _df(self, jdf):
from pyspark.sql.dataframe import DataFrame
return DataFrame(jdf, self._sqlContext)
+ @since(1.4)
+ def format(self, source):
+ """Specifies the input data source format.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+ >>> df.dtypes
+ [('age', 'bigint'), ('name', 'string')]
+
+ """
+ self._jreader = self._jreader.format(source)
+ return self
+
+ @since(1.4)
+ def schema(self, schema):
+ """Specifies the input schema.
+
+ Some data sources (e.g. JSON) can infer the input schema automatically from data.
+ By specifying the schema here, the underlying data source can skip the schema
+ inference step, and thus speed up data loading.
+
+ :param schema: a StructType object
+ """
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+ self._jreader = self._jreader.schema(jschema)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """Adds input options for the underlying data source.
+ """
+ for k in options:
+ self._jreader = self._jreader.option(k, options[k])
+ return self
+
@since(1.4)
def load(self, path=None, format=None, schema=None, **options):
"""Loads data from a data source and returns it as a :class`DataFrame`.
@@ -51,21 +89,20 @@ def load(self, path=None, format=None, schema=None, **options):
:param format: optional string for format of the data source. Default to 'parquet'.
:param schema: optional :class:`StructType` for the input schema.
:param options: all other string options
+
+ >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned')
+ >>> df.dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
- jreader = self._jreader
if format is not None:
- jreader = jreader.format(format)
+ self.format(format)
if schema is not None:
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jreader = jreader.schema(jschema)
- for k in options:
- jreader = jreader.option(k, options[k])
+ self.schema(schema)
+ self.options(**options)
if path is not None:
- return self._df(jreader.load(path))
+ return self._df(self._jreader.load(path))
else:
- return self._df(jreader.load())
+ return self._df(self._jreader.load())
@since(1.4)
def json(self, path, schema=None):
@@ -79,47 +116,25 @@ def json(self, path, schema=None):
:param path: string, path to the JSON dataset.
:param schema: an optional :class:`StructType` for the input schema.
- >>> import tempfile, shutil
- >>> jsonFile = tempfile.mkdtemp()
- >>> shutil.rmtree(jsonFile)
- >>> with open(jsonFile, 'w') as f:
- ... f.writelines(jsonStrings)
- >>> df1 = sqlContext.read.json(jsonFile)
- >>> df1.printSchema()
- root
- |-- field1: long (nullable = true)
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field4: long (nullable = true)
-
- >>> from pyspark.sql.types import *
- >>> schema = StructType([
- ... StructField("field2", StringType()),
- ... StructField("field3",
- ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
- >>> df2 = sqlContext.read.json(jsonFile, schema)
- >>> df2.printSchema()
- root
- |-- field2: string (nullable = true)
- |-- field3: struct (nullable = true)
- | |-- field5: array (nullable = true)
- | | |-- element: integer (containsNull = true)
+ >>> df = sqlContext.read.json('python/test_support/sql/people.json')
+ >>> df.dtypes
+ [('age', 'bigint'), ('name', 'string')]
+
"""
- if schema is None:
- jdf = self._jreader.json(path)
- else:
- jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
- jdf = self._jreader.schema(jschema).json(path)
- return self._df(jdf)
+ if schema is not None:
+ self.schema(schema)
+ return self._df(self._jreader.json(path))
@since(1.4)
def table(self, tableName):
"""Returns the specified table as a :class:`DataFrame`.
- >>> sqlContext.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlContext.read.table("table1")
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ :param tableName: string, name of the table.
+
+ >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df.registerTempTable('tmpTable')
+ >>> sqlContext.read.table('tmpTable').dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.table(tableName))
@@ -127,13 +142,9 @@ def table(self, tableName):
def parquet(self, *path):
"""Loads a Parquet file, returning the result as a :class:`DataFrame`.
- >>> import tempfile, shutil
- >>> parquetFile = tempfile.mkdtemp()
- >>> shutil.rmtree(parquetFile)
- >>> df.saveAsParquetFile(parquetFile)
- >>> df2 = sqlContext.read.parquet(parquetFile)
- >>> sorted(df.collect()) == sorted(df2.collect())
- True
+ >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+ >>> df.dtypes
+ [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
"""
return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path)))
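
A minimal sketch of the builder-style reader these new methods enable; format/schema/options each return the same DataFrameReader, so they chain before load(). It assumes a local Spark 1.4 build run from the Spark source root so the test_support paths added by this patch resolve:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.sql.types import LongType, StringType, StructField, StructType

    sc = SparkContext("local[2]", "reader-builder")   # hypothetical app name
    sqlContext = SQLContext(sc)

    df1 = sqlContext.read.format('json').load('python/test_support/sql/people.json')
    print(df1.dtypes)   # [('age', 'bigint'), ('name', 'string')]

    # Supplying a schema up front lets the JSON source skip schema inference.
    schema = StructType([StructField('name', StringType()),
                         StructField('age', LongType())])
    df2 = sqlContext.read.schema(schema).json('python/test_support/sql/people.json')
    print(df2.dtypes)   # [('name', 'string'), ('age', 'bigint')]

    sc.stop()
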
@@ -195,40 +206,88 @@ def __init__(self, df):
self._jwrite = df._jdf.write()
@since(1.4)
- def save(self, path=None, format=None, mode="error", **options):
- """
- Saves the contents of the :class:`DataFrame` to a data source.
+ def mode(self, saveMode):
+ """Specifies the behavior when data or table already exists.
- The data source is specified by the ``format`` and a set of ``options``.
- If ``format`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ Options include:
* `append`: Append contents of this :class:`DataFrame` to existing data.
* `overwrite`: Overwrite existing data.
* `error`: Throw an exception if data already exists.
* `ignore`: Silently ignore this operation if data already exists.
+ >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ self._jwrite = self._jwrite.mode(saveMode)
+ return self
+
+ @since(1.4)
+ def format(self, source):
+ """Specifies the underlying output data source.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ self._jwrite = self._jwrite.format(source)
+ return self
+
+ @since(1.4)
+ def options(self, **options):
+ """Adds output options for the underlying data source.
+ """
+ for k in options:
+ self._jwrite = self._jwrite.option(k, options[k])
+ return self
+
+ @since(1.4)
+ def partitionBy(self, *cols):
+ """Partitions the output by the given columns on the file system.
+
+ If specified, the output is laid out on the file system similar
+ to Hive's partitioning scheme.
+
+ :param cols: names of columns
+
+ >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0]
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+ return self
+
+ @since(1.4)
+ def save(self, path=None, format=None, mode="error", **options):
+ """Saves the contents of the :class:`DataFrame` to a data source.
+
+ The data source is specified by the ``format`` and a set of ``options``.
+ If ``format`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
:param path: the path in a Hadoop supported file system
:param format: the format used to save
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param mode: specifies the behavior of the save operation when data already exists.
+
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
:param options: all other string options
+
+ >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
+ self.format(format)
if path is None:
- jwrite.save()
+ self._jwrite.save()
else:
- jwrite.save(path)
+ self._jwrite.save(path)
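
The writer gains the same builder shape: mode/format/options/partitionBy all return the DataFrameWriter, so they chain before save(). A minimal sketch, assuming a local Spark 1.4 build and a hypothetical app name and output directory:

    import os
    import tempfile
    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "writer-builder")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(name=u'Alice', year=2014, month=9, day=1),
                                     Row(name=u'Bob', year=2015, month=10, day=25)])

    out = os.path.join(tempfile.mkdtemp(), 'data')
    df.write.mode('overwrite').format('parquet').partitionBy('year', 'month').save(out)

    # The partition columns come back when the directory is read again.
    print(sqlContext.read.parquet(out).dtypes)

    sc.stop()
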
+ @since(1.4)
def insertInto(self, tableName, overwrite=False):
- """
- Inserts the content of the :class:`DataFrame` to the specified table.
+ """Inserts the content of the :class:`DataFrame` to the specified table.
+
It requires that the schema of the :class:`DataFrame` is the same as the
schema of the table.
@@ -238,8 +297,7 @@ def insertInto(self, tableName, overwrite=False):
@since(1.4)
def saveAsTable(self, name, format=None, mode="error", **options):
- """
- Saves the content of the :class:`DataFrame` as the specified table.
+ """Saves the content of the :class:`DataFrame` as the specified table.
In the case the table already exists, behavior of this function depends on the
save mode, specified by the `mode` function (default to throwing an exception).
@@ -256,72 +314,61 @@ def saveAsTable(self, name, format=None, mode="error", **options):
:param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
:param options: all other string options
"""
- jwrite = self._jwrite.mode(mode)
+ self.mode(mode).options(**options)
if format is not None:
- jwrite = jwrite.format(format)
- for k in options:
- jwrite = jwrite.option(k, options[k])
- return jwrite.saveAsTable(name)
+ self.format(format)
+ self._jwrite.saveAsTable(name)
@since(1.4)
def json(self, path, mode="error"):
- """
- Saves the content of the :class:`DataFrame` in JSON format at the
- specified path.
+ """Saves the content of the :class:`DataFrame` in JSON format at the specified path.
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ :param path: the path in any Hadoop supported file system
+ :param mode: specifies the behavior of the save operation when data already exists.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
- :param path: the path in any Hadoop supported file system
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- return self._jwrite.mode(mode).json(path)
+ self._jwrite.mode(mode).json(path)
@since(1.4)
def parquet(self, path, mode="error"):
- """
- Saves the content of the :class:`DataFrame` in Parquet format at the
- specified path.
+ """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
- Additionally, mode is used to specify the behavior of the save operation when
- data already exists in the data source. There are four modes:
+ :param path: the path in any Hadoop supported file system
+ :param mode: specifies the behavior of the save operation when data already exists.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
- :param path: the path in any Hadoop supported file system
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
"""
- return self._jwrite.mode(mode).parquet(path)
+ self._jwrite.mode(mode).parquet(path)
@since(1.4)
def jdbc(self, url, table, mode="error", properties={}):
- """
- Saves the content of the :class:`DataFrame` to a external database table
- via JDBC.
-
- In the case the table already exists in the external database,
- behavior of this function depends on the save mode, specified by the `mode`
- function (default to throwing an exception). There are four modes:
+ """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
- * `append`: Append contents of this :class:`DataFrame` to existing data.
- * `overwrite`: Overwrite existing data.
- * `error`: Throw an exception if data already exists.
- * `ignore`: Silently ignore this operation if data already exists.
+ .. note:: Don't create too many partitions in parallel on a large cluster;\
+ otherwise Spark might crash your external database systems.
- :param url: a JDBC URL of the form `jdbc:subprotocol:subname`
+ :param url: a JDBC URL of the form ``jdbc:subprotocol:subname``
:param table: Name of the table in the external database.
- :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
+ :param mode: specifies the behavior of the save operation when data already exists.
+
+ * ``append``: Append contents of this :class:`DataFrame` to existing data.
+ * ``overwrite``: Overwrite existing data.
+ * ``ignore``: Silently ignore this operation if data already exists.
+ * ``error`` (default case): Throw an exception if data already exists.
:param properties: JDBC database connection arguments, a list of
- arbitrary string tag/value. Normally at least a
- "user" and "password" property should be included.
+ arbitrary string tag/value. Normally at least a
+ "user" and "password" property should be included.
"""
jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
for k in properties:
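
A heavily hedged sketch of the jdbc() writer: the URL, table name, and credentials below are hypothetical, and running it requires a reachable database plus the matching JDBC driver jar on the Spark classpath.

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local[2]", "jdbc-write")   # hypothetical app name
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([Row(name=u'Alice', age=2)])

    # Hypothetical PostgreSQL endpoint and credentials.
    df.write.jdbc("jdbc:postgresql://localhost:5432/testdb", "people",
                  mode="append", properties={"user": "spark", "password": "secret"})

    sc.stop()
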
@@ -331,24 +378,23 @@ def jdbc(self, url, table, mode="error", properties={}):
def _test():
import doctest
+ import os
+ import tempfile
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.readwriter
+
+ os.chdir(os.environ["SPARK_HOME"])
+
globs = pyspark.sql.readwriter.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
+
+ globs['tempfile'] = tempfile
+ globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
- globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
- .toDF(StructType([StructField('age', IntegerType()),
- StructField('name', StringType())]))
- jsonStrings = [
- '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
- '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
- '"field6":[{"field7": "row2"}]}',
- '{"field1" : null, "field2": "row3", '
- '"field3":{"field4":33, "field5": []}}'
- ]
- globs['jsonStrings'] = jsonStrings
+ globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5c53c3a8ed4f1..a6fce50c76c2b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -100,6 +100,15 @@ def test_data_type_eq(self):
lt2 = pickle.loads(pickle.dumps(LongType()))
self.assertEquals(lt, lt2)
+ # regression test for SPARK-7978
+ def test_decimal_type(self):
+ t1 = DecimalType()
+ t2 = DecimalType(10, 2)
+ self.assertTrue(t2 is not t1)
+ self.assertNotEqual(t1, t2)
+ t3 = DecimalType(8)
+ self.assertNotEqual(t2, t3)
+
class SQLTests(ReusedPySparkTestCase):
@@ -122,6 +131,8 @@ def test_range(self):
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+ self.assertEqual(self.sqlCtx.range(-2).count(), 0)
+ self.assertEqual(self.sqlCtx.range(3).count(), 3)
def test_explode(self):
from pyspark.sql.functions import explode
@@ -744,8 +755,10 @@ def setUpClass(cls):
try:
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
except py4j.protocol.Py4JError:
+ cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
except TypeError:
+ cls.tearDownClass()
raise unittest.SkipTest("Hive is not available")
os.unlink(cls.tempdir.name)
_scala_HiveContext =\
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py
similarity index 99%
rename from python/pyspark/sql/_types.py
rename to python/pyspark/sql/types.py
index 9e7e9f04bc35d..b6ec6137c9180 100644
--- a/python/pyspark/sql/_types.py
+++ b/python/pyspark/sql/types.py
@@ -97,8 +97,6 @@ class AtomicType(DataType):
"""An internal type used to represent everything that is not
null, UDTs, arrays, structs, and maps."""
- __metaclass__ = DataTypeSingleton
-
class NumericType(AtomicType):
"""Numeric data types.
@@ -109,6 +107,8 @@ class IntegralType(NumericType):
"""Integral data types.
"""
+ __metaclass__ = DataTypeSingleton
+
class FractionalType(NumericType):
"""Fractional data types.
@@ -119,26 +119,36 @@ class StringType(AtomicType):
"""String data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BinaryType(AtomicType):
"""Binary (byte array) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class BooleanType(AtomicType):
"""Boolean data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DateType(AtomicType):
"""Date (datetime.date) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class TimestampType(AtomicType):
"""Timestamp (datetime.datetime) data type.
"""
+ __metaclass__ = DataTypeSingleton
+
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
@@ -172,11 +182,15 @@ class DoubleType(FractionalType):
"""Double data type, representing double precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class FloatType(FractionalType):
"""Float data type, representing single precision floats.
"""
+ __metaclass__ = DataTypeSingleton
+
class ByteType(IntegralType):
"""Byte data type, i.e. a signed integer in a single byte.
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 0a0e006bdf83a..c74745c726a0c 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -32,7 +32,6 @@ def _to_java_cols(cols):
class Window(object):
-
"""
Utility functions for defining window in DataFrames.
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 33ea8c9293d74..57049beea4dba 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -41,8 +41,8 @@
class PySparkStreamingTestCase(unittest.TestCase):
- timeout = 4 # seconds
- duration = .2
+ timeout = 10 # seconds
+ duration = .5
@classmethod
def setUpClass(cls):
@@ -379,13 +379,13 @@ def func(dstream):
class WindowFunctionTests(PySparkStreamingTestCase):
- timeout = 5
+ timeout = 15
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.window(.6, .2).count()
+ return dstream.window(1.5, .5).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -394,7 +394,7 @@ def test_count_by_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
- return dstream.countByWindow(.6, .2)
+ return dstream.countByWindow(1.5, .5)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
@@ -403,7 +403,7 @@ def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByWindow(1, .2)
+ return dstream.countByWindow(2.5, .5)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
@@ -412,7 +412,7 @@ def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
- return dstream.countByValueAndWindow(1, .2)
+ return dstream.countByValueAndWindow(2.5, .5)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
@@ -421,7 +421,7 @@ def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]
def func(dstream):
- return dstream.groupByKeyAndWindow(.6, .2).mapValues(list)
+ return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
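
A minimal sketch of why the window and slide arguments changed: they must be multiples of the batch interval, which these tests raised to 0.5s. It assumes a local Spark 1.4 build and a hypothetical app name; the sleep is only there to let a few batches run.

    import time
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "window-sketch")   # hypothetical app name
    ssc = StreamingContext(sc, 0.5)                  # 0.5s batch interval, as in the tests

    # One small RDD per batch, mirroring input = [range(1), range(2), ...] above.
    rdds = [sc.parallelize(range(i + 1)) for i in range(5)]
    stream = ssc.queueStream(rdds)

    # window(1.5, 0.5): both durations are multiples of the 0.5s batch interval.
    stream.window(1.5, 0.5).count().pprint()

    ssc.start()
    time.sleep(6)
    ssc.stop()
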
@@ -615,7 +615,6 @@ def test_kafka_stream(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
"test-streaming-consumer", {topic: 1},
@@ -631,7 +630,6 @@ def test_kafka_direct_stream(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
self._validateStreamResult(sendData, stream)
@@ -646,7 +644,6 @@ def test_kafka_direct_stream_from_offset(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
self._validateStreamResult(sendData, stream)
@@ -661,7 +658,6 @@ def test_kafka_rdd(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self._validateRddResult(sendData, rdd)
@@ -677,7 +673,6 @@ def test_kafka_rdd_with_leaders(self):
self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
- self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)
diff --git a/python/run-tests b/python/run-tests
index ffde2fb24b369..4468fdb3f267e 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,54 +57,57 @@ function run_test() {
function run_core_tests() {
echo "Run core tests ..."
- run_test "pyspark/rdd.py"
- run_test "pyspark/context.py"
- run_test "pyspark/conf.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- run_test "pyspark/serializers.py"
- run_test "pyspark/profiler.py"
- run_test "pyspark/shuffle.py"
- run_test "pyspark/tests.py"
+ run_test "pyspark.rdd"
+ run_test "pyspark.context"
+ run_test "pyspark.conf"
+ run_test "pyspark.broadcast"
+ run_test "pyspark.accumulators"
+ run_test "pyspark.serializers"
+ run_test "pyspark.profiler"
+ run_test "pyspark.shuffle"
+ run_test "pyspark.tests"
}
function run_sql_tests() {
echo "Run sql tests ..."
- run_test "pyspark/sql/_types.py"
- run_test "pyspark/sql/context.py"
- run_test "pyspark/sql/column.py"
- run_test "pyspark/sql/dataframe.py"
- run_test "pyspark/sql/group.py"
- run_test "pyspark/sql/functions.py"
- run_test "pyspark/sql/tests.py"
+ run_test "pyspark.sql.types"
+ run_test "pyspark.sql.context"
+ run_test "pyspark.sql.column"
+ run_test "pyspark.sql.dataframe"
+ run_test "pyspark.sql.group"
+ run_test "pyspark.sql.functions"
+ run_test "pyspark.sql.readwriter"
+ run_test "pyspark.sql.window"
+ run_test "pyspark.sql.tests"
}
function run_mllib_tests() {
echo "Run mllib tests ..."
- run_test "pyspark/mllib/classification.py"
- run_test "pyspark/mllib/clustering.py"
- run_test "pyspark/mllib/evaluation.py"
- run_test "pyspark/mllib/feature.py"
- run_test "pyspark/mllib/fpm.py"
- run_test "pyspark/mllib/linalg.py"
- run_test "pyspark/mllib/rand.py"
- run_test "pyspark/mllib/recommendation.py"
- run_test "pyspark/mllib/regression.py"
- run_test "pyspark/mllib/stat/_statistics.py"
- run_test "pyspark/mllib/tree.py"
- run_test "pyspark/mllib/util.py"
- run_test "pyspark/mllib/tests.py"
+ run_test "pyspark.mllib.classification"
+ run_test "pyspark.mllib.clustering"
+ run_test "pyspark.mllib.evaluation"
+ run_test "pyspark.mllib.feature"
+ run_test "pyspark.mllib.fpm"
+ run_test "pyspark.mllib.linalg"
+ run_test "pyspark.mllib.random"
+ run_test "pyspark.mllib.recommendation"
+ run_test "pyspark.mllib.regression"
+ run_test "pyspark.mllib.stat._statistics"
+ run_test "pyspark.mllib.stat.KernelDensity"
+ run_test "pyspark.mllib.tree"
+ run_test "pyspark.mllib.util"
+ run_test "pyspark.mllib.tests"
}
function run_ml_tests() {
echo "Run ml tests ..."
- run_test "pyspark/ml/feature.py"
- run_test "pyspark/ml/classification.py"
- run_test "pyspark/ml/recommendation.py"
- run_test "pyspark/ml/regression.py"
- run_test "pyspark/ml/tuning.py"
- run_test "pyspark/ml/tests.py"
- run_test "pyspark/ml/evaluation.py"
+ run_test "pyspark.ml.feature"
+ run_test "pyspark.ml.classification"
+ run_test "pyspark.ml.recommendation"
+ run_test "pyspark.ml.regression"
+ run_test "pyspark.ml.tuning"
+ run_test "pyspark.ml.tests"
+ run_test "pyspark.ml.evaluation"
}
function run_streaming_tests() {
@@ -124,8 +127,8 @@ function run_streaming_tests() {
done
export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
- run_test "pyspark/streaming/util.py"
- run_test "pyspark/streaming/tests.py"
+ run_test "pyspark.streaming.util"
+ run_test "pyspark.streaming.tests"
}
echo "Running PySpark tests. Output is in python/$LOG_FILE."
diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata
new file mode 100644
index 0000000000000..7ef2320651dee
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ
diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata
new file mode 100644
index 0000000000000..78a1ca7d38279
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc
new file mode 100644
index 0000000000000..e93f42ed6f350
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet
new file mode 100644
index 0000000000000..461c382937ecd
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc
new file mode 100644
index 0000000000000..b63c4d6d1e1dc
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc
new file mode 100644
index 0000000000000..5bc0ebd713563
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet
new file mode 100644
index 0000000000000..62a63915beac2
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet
new file mode 100644
index 0000000000000..67665a7b55da6
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc
new file mode 100644
index 0000000000000..ae94a15d08c81
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet
new file mode 100644
index 0000000000000..6cb8538aa8904
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc
new file mode 100644
index 0000000000000..58d9bb5fc5883
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ
diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet
new file mode 100644
index 0000000000000..9b00805481e7b
Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ
diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json
new file mode 100644
index 0000000000000..50a859cbd7ee8
--- /dev/null
+++ b/python/test_support/sql/people.json
@@ -0,0 +1,3 @@
+{"name":"Michael"}
+{"name":"Andy", "age":30}
+{"name":"Justin", "age":19}
diff --git a/repl/pom.xml b/repl/pom.xml
index 03053b4c3b287..85f7bc8ac1024 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
@@ -48,6 +48,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-bagel_${scala.binary.version}</artifactId>
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 934daaeaafca1..50fd43a418bca 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -22,13 +22,12 @@ import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 14f5e9ed4f25e..9ecc7c229e38a 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.tools.nsc.interpreter.SparkILoop
-import org.scalatest.FunSuite
import org.apache.commons.lang3.StringEscapeUtils
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils
-class ReplSuite extends FunSuite {
+class ReplSuite extends SparkFunSuite {
def runInterpreter(master: String, input: String): String = {
val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath"
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index c709cde740748..a58eda12b1120 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -25,7 +25,6 @@ import scala.language.implicitConversions
import scala.language.postfixOps
import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
import org.scalatest.concurrent.Interruptor
import org.scalatest.concurrent.Timeouts._
import org.scalatest.mock.MockitoSugar
@@ -35,7 +34,7 @@ import org.apache.spark._
import org.apache.spark.util.Utils
class ExecutorClassLoaderSuite
- extends FunSuite
+ extends SparkFunSuite
with BeforeAndAfterAll
with MockitoSugar
with Logging {
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 7168d5b2a8e26..d6f927b6fa803 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -14,25 +14,41 @@
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
 [The remainder of the scalastyle-config.xml diff lost its XML markup during extraction; only element text survived. The recoverable fragments show the "Scalastyle standard configuration" being reorganized and extended with, among other checks: whitespace rules around the tokens ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW (plus a wider set including COMMA, COLON, IF, DO, WHILE, FOR, MATCH); a regex rule matching ^FunSuite[A-Za-z]*$ with the message "Tests must extend org.apache.spark.SparkFunSuite instead."; a regex rule flagging ^println$; and size limits including a file length of 800, 30 methods per type, 10 parameters per method, 50 lines per method, and a magic-number whitelist of -1,0,1,2,3.]
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5c322d032d474..f4b1cc3a4ffe7 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.4.0-SNAPSHOT</version>
+    <version>1.5.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
@@ -36,10 +36,6 @@
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>scala-compiler</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-reflect</artifactId>
@@ -50,6 +46,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
@@ -60,6 +63,11 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.codehaus.janino</groupId>
+      <artifactId>janino</artifactId>
+      <version>2.7.8</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -101,13 +109,6 @@
         <name>!scala-2.11</name>
       </property>
     </activation>
-    <dependencies>
-      <dependency>
-        <groupId>org.scalamacros</groupId>
-        <artifactId>quasiquotes_${scala.binary.version}</artifactId>
-        <version>${scala.macros.version}</version>
-      </dependency>
-    </dependencies>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb546b3086b33..ec97fe603c44f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,23 +17,25 @@
package org.apache.spark.sql.catalyst.expressions;
-import scala.collection.Map;
+import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
import scala.collection.Seq;
import scala.collection.mutable.ArraySeq;
-import javax.annotation.Nullable;
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.util.*;
-
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.BaseMutableRow;
import org.apache.spark.sql.types.DataType;
-import static org.apache.spark.sql.types.DataTypes.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import static org.apache.spark.sql.types.DataTypes.*;
+
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
*
@@ -49,7 +51,7 @@
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
-public final class UnsafeRow implements MutableRow {
+public final class UnsafeRow extends BaseMutableRow {
private Object baseObject;
private long baseOffset;
@@ -227,21 +229,11 @@ public int size() {
return numFields;
}
- @Override
- public int length() {
- return size();
- }
-
@Override
public StructType schema() {
return schema;
}
- @Override
- public Object apply(int i) {
- return get(i);
- }
-
@Override
public Object get(int i) {
assertIndexIsValid(i);
@@ -339,60 +331,7 @@ public String getString(int i) {
return getUTF8String(i).toString();
}
- @Override
- public BigDecimal getDecimal(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Date getDate(int i) {
- throw new UnsupportedOperationException();
- }
- @Override
- public Seq getSeq(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public List getList(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Map getMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public scala.collection.immutable.Map getValuesMap(Seq fieldNames) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public java.util.Map getJavaMap(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Row getStruct(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public T getAs(int i) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public T getAs(String fieldName) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int fieldIndex(String name) {
- throw new UnsupportedOperationException();
- }
@Override
public Row copy() {
@@ -412,24 +351,4 @@ public Seq